diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f679fccc..a27f90a78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,12 @@ They are now ignored and will be removed in a future release. - The undocumented return value has been removed. If you need information about the remote server, use `driver.get_server_info()` instead. +- Transaction functions (a.k.a. managed transactions): + The first argument of transaction functions is now a `ManagedTransaction` + object. It behaves exactly like a regular `Transaction` object, except it + does not offer the `commit`, `rollback`, `close`, and `closed` methods. + Those methods would have caused a hard to interpreted error previously. Hence, + they have been removed. ## Version 4.4 diff --git a/bin/make-unasync b/bin/make-unasync index 79d187d9c..bde42d7e0 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -214,6 +214,7 @@ def apply_unasync(files): "_async": "_sync", "mark_async_test": "mark_sync_test", "assert_awaited_once": "assert_called_once", + "assert_awaited_once_with": "assert_called_once_with", } additional_testkit_backend_replacements = {} rules = [ diff --git a/docs/source/api.rst b/docs/source/api.rst index 1dc77f203..bf6059b5e 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -422,16 +422,18 @@ Will result in: *********************** Sessions & Transactions *********************** -All database activity is co-ordinated through two mechanisms: the :class:`neo4j.Session` and the :class:`neo4j.Transaction`. +All database activity is co-ordinated through two mechanisms: +**sessions** (:class:`neo4j.AsyncSession`) and **transactions** +(:class:`neo4j.Transaction`, :class:`neo4j.ManagedTransaction`). -A :class:`neo4j.Session` is a logical container for any number of causally-related transactional units of work. +A **session** is a logical container for any number of causally-related transactional units of work. Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required. Sessions provide the top level of containment for database activity. Session creation is a lightweight operation and *sessions are not thread safe*. Connections are drawn from the :class:`neo4j.Driver` connection pool as required. -A :class:`neo4j.Transaction` is a unit of work that is either committed in its entirety or is rolled back on failure. +A **transaction** is a unit of work that is either committed in its entirety or is rolled back on failure. .. _session-construction-ref: @@ -724,7 +726,6 @@ Example: node_id = create_person_node(tx) set_person_name(tx, node_id, name) tx.commit() - tx.close() def create_person_node(tx): query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" @@ -753,6 +754,12 @@ This function is called one or more times, within a configurable time limit, unt Results should be fully consumed within the function and only aggregate or status values should be returned. Returning a live result object would prevent the driver from correctly managing connections and would break retry guarantees. +This function will receive a :class:`neo4j.ManagedTransaction` object as its first parameter. + +.. autoclass:: neo4j.ManagedTransaction + + .. automethod:: run + Example: .. code-block:: python @@ -811,7 +818,7 @@ A :class:`neo4j.Result` is attached to an active connection, through a :class:`n .. automethod:: closed -See https://neo4j.com/docs/driver-manual/current/cypher-workflow/#driver-type-mapping for more about type mapping. +See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping. Graph diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 1d613b76c..64e2f3349 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -235,16 +235,18 @@ Will result in: ********************************* AsyncSessions & AsyncTransactions ********************************* -All database activity is co-ordinated through two mechanisms: the :class:`neo4j.AsyncSession` and the :class:`neo4j.AsyncTransaction`. +All database activity is co-ordinated through two mechanisms: +**sessions** (:class:`neo4j.AsyncSession`) and **transactions** +(:class:`neo4j.AsyncTransaction`, :class:`neo4j.AsyncManagedTransaction`). -A :class:`neo4j.AsyncSession` is a logical container for any number of causally-related transactional units of work. +A **session** is a logical container for any number of causally-related transactional units of work. Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required. Sessions provide the top level of containment for database activity. -Session creation is a lightweight operation and *sessions cannot be shared between coroutines*. +Session creation is a lightweight operation and *sessions are not thread safe*. Connections are drawn from the :class:`neo4j.AsyncDriver` connection pool as required. -A :class:`neo4j.AsyncTransaction` is a unit of work that is either committed in its entirety or is rolled back on failure. +A **transaction** is a unit of work that is either committed in its entirety or is rolled back on failure. .. _async-session-construction-ref: @@ -417,7 +419,6 @@ Example: node_id = await create_person_node(tx) await set_person_name(tx, node_id, name) await tx.commit() - await tx.close() async def create_person_node(tx): query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" @@ -447,6 +448,12 @@ This function is called one or more times, within a configurable time limit, unt Results should be fully consumed within the function and only aggregate or status values should be returned. Returning a live result object would prevent the driver from correctly managing connections and would break retry guarantees. +This function will receive a :class:`neo4j.AsyncManagedTransaction` object as its first parameter. + +.. autoclass:: neo4j.AsyncManagedTransaction + + .. automethod:: run + Example: .. code-block:: python @@ -505,4 +512,4 @@ A :class:`neo4j.AsyncResult` is attached to an active connection, through a :cla .. automethod:: closed -See https://neo4j.com/docs/driver-manual/current/cypher-workflow/#driver-type-mapping for more about type mapping. +See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping. diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 6aeb2f56c..d60674b22 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -23,6 +23,7 @@ "AsyncDriver", "AsyncGraphDatabase", "AsyncNeo4jDriver", + "AsyncManagedTransaction", "AsyncResult", "AsyncSession", "AsyncTransaction", @@ -42,6 +43,7 @@ "IPv4Address", "IPv6Address", "kerberos_auth", + "ManagedTransaction", "Neo4jDriver", "PoolConfig", "Query", @@ -72,6 +74,7 @@ AsyncNeo4jDriver, ) from ._async.work import ( + AsyncManagedTransaction, AsyncResult, AsyncSession, AsyncTransaction, @@ -83,6 +86,7 @@ Neo4jDriver, ) from ._sync.work import ( + ManagedTransaction, Result, Session, Transaction, diff --git a/neo4j/_async/work/__init__.py b/neo4j/_async/work/__init__.py index e48e1c212..da66b6987 100644 --- a/neo4j/_async/work/__init__.py +++ b/neo4j/_async/work/__init__.py @@ -19,14 +19,18 @@ from .session import ( AsyncResult, AsyncSession, - AsyncTransaction, AsyncWorkspace, ) +from .transaction import ( + AsyncManagedTransaction, + AsyncTransaction, +) __all__ = [ "AsyncResult", "AsyncSession", "AsyncTransaction", + "AsyncManagedTransaction", "AsyncWorkspace", ] diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index 081acde63..fd84e7633 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -16,7 +16,6 @@ # limitations under the License. -import asyncio from logging import getLogger from random import random from time import perf_counter @@ -44,7 +43,10 @@ ) from ...work import Query from .result import AsyncResult -from .transaction import AsyncTransaction +from .transaction import ( + AsyncManagedTransaction, + AsyncTransaction, +) from .workspace import AsyncWorkspace @@ -157,8 +159,9 @@ async def close(self): self._state_failed = True if self._transaction: - if self._transaction.closed() is False: - await self._transaction.rollback() # roll back the transaction if it is not closed + if self._transaction._closed() is False: + # roll back the transaction if it is not closed + await self._transaction._rollback() self._transaction = None try: @@ -306,7 +309,7 @@ async def last_bookmarks(self): if self._auto_result: await self._auto_result.consume() - if self._transaction and self._transaction._closed: + if self._transaction and self._transaction._closed(): self._collect_bookmark(self._transaction._bookmark) self._transaction = None @@ -323,10 +326,10 @@ async def _transaction_error_handler(self, _): self._transaction = None await self._disconnect() - async def _open_transaction(self, *, access_mode, metadata=None, + async def _open_transaction(self, *, tx_cls, access_mode, metadata=None, timeout=None): await self._connect(access_mode=access_mode) - self._transaction = AsyncTransaction( + self._transaction = tx_cls( self._connection, self._config.fetch_size, self._transaction_closed_handler, self._transaction_error_handler @@ -372,6 +375,7 @@ async def begin_transaction(self, metadata=None, timeout=None): raise TransactionError("Explicit transaction already open") await self._open_transaction( + tx_cls=AsyncTransaction, access_mode=self._config.default_access_mode, metadata=metadata, timeout=timeout ) @@ -396,6 +400,7 @@ async def _run_transaction( while True: try: await self._open_transaction( + tx_cls=AsyncManagedTransaction, access_mode=access_mode, metadata=metadata, timeout=timeout ) @@ -403,10 +408,10 @@ async def _run_transaction( try: result = await transaction_function(tx, *args, **kwargs) except Exception: - await tx.close() + await tx._close() raise else: - await tx.commit() + await tx._commit() except IncompleteCommit: raise except (ServiceUnavailable, SessionExpired) as error: diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py index 83408da41..766f242a9 100644 --- a/neo4j/_async/work/transaction.py +++ b/neo4j/_async/work/transaction.py @@ -16,6 +16,8 @@ # limitations under the License. +from functools import wraps + from ..._async_compat.util import AsyncUtil from ...data import DataHydrator from ...exceptions import TransactionError @@ -24,17 +26,10 @@ from .result import AsyncResult -class AsyncTransaction: - """ Container for multiple Cypher queries to be executed within a single - context. asynctransactions can be used within a :py:const:`async with` - block where the transaction is committed or rolled back on based on - whether an exception is raised:: - - async with session.begin_transaction() as tx: - ... +__all__ = ("AsyncTransaction", "AsyncManagedTransaction") - """ +class _AsyncTransactionBase: def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -42,22 +37,22 @@ def __init__(self, connection, fetch_size, on_closed, on_error): ) self._bookmark = None self._results = [] - self._closed = False + self._closed_flag = False self._last_error = None self._fetch_size = fetch_size self._on_closed = on_closed self._on_error = on_error - async def __aenter__(self): + async def _enter(self): return self - async def __aexit__(self, exception_type, exception_value, traceback): - if self._closed: + async def _exit(self, exception_type, exception_value, traceback): + if self._closed_flag: return success = not bool(exception_type) if success: - await self.commit() - await self.close() + await self._commit() + await self._close() async def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout @@ -105,14 +100,16 @@ async def run(self, query, parameters=None, **kwparameters): :param parameters: dictionary of parameters :type parameters: dict :param kwparameters: additional keyword parameters - :returns: a new :class:`neo4j.Result` object - :rtype: :class:`neo4j.Result` + + :returns: a new :class:`neo4j.AsyncResult` object + :rtype: :class:`neo4j.AsyncResult` + :raise TransactionError: if the transaction is already closed """ if isinstance(query, Query): raise ValueError("Query object is only supported for session.run") - if self._closed: + if self._closed_flag: raise TransactionError(self, "Transaction closed") if self._last_error: raise TransactionError(self, @@ -136,12 +133,12 @@ async def run(self, query, parameters=None, **kwparameters): return result - async def commit(self): + async def _commit(self): """Mark this transaction as successful and close in order to trigger a COMMIT. :raise TransactionError: if the transaction is already closed """ - if self._closed: + if self._closed_flag: raise TransactionError(self, "Transaction closed") if self._last_error: raise TransactionError(self, @@ -156,17 +153,17 @@ async def commit(self): await self._connection.fetch_all() self._bookmark = metadata.get("bookmark") finally: - self._closed = True + self._closed_flag = True await AsyncUtil.callback(self._on_closed) return self._bookmark - async def rollback(self): + async def _rollback(self): """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK. :raise TransactionError: if the transaction is already closed """ - if self._closed: + if self._closed_flag: raise TransactionError(self, "Transaction closed") metadata = {} @@ -180,20 +177,82 @@ async def rollback(self): await self._connection.send_all() await self._connection.fetch_all() finally: - self._closed = True + self._closed_flag = True await AsyncUtil.callback(self._on_closed) - async def close(self): + async def _close(self): """Close this transaction, triggering a ROLLBACK if not closed. """ - if self._closed: + if self._closed_flag: return - await self.rollback() + await self._rollback() - def closed(self): + def _closed(self): """Indicator to show whether the transaction has been closed. :return: :const:`True` if closed, :const:`False` otherwise. :rtype: bool """ - return self._closed + return self._closed_flag + + +class AsyncTransaction(_AsyncTransactionBase): + """ Container for multiple Cypher queries to be executed within a single + context. asynctransactions can be used within a :py:const:`async with` + block where the transaction is committed or rolled back on based on + whether an exception is raised:: + + async with session.begin_transaction() as tx: + ... + + """ + + @wraps(_AsyncTransactionBase._enter) + async def __aenter__(self): + return await self._enter() + + @wraps(_AsyncTransactionBase._exit) + async def __aexit__(self, exception_type, exception_value, traceback): + await self._exit(exception_type, exception_value, traceback) + + @wraps(_AsyncTransactionBase._commit) + async def commit(self): + return await self._commit() + + @wraps(_AsyncTransactionBase._rollback) + async def rollback(self): + return await self._rollback() + + @wraps(_AsyncTransactionBase._close) + async def close(self): + return await self._close() + + @wraps(_AsyncTransactionBase._closed) + def closed(self): + return self._closed() + + +class AsyncManagedTransaction(_AsyncTransactionBase): + """Transaction object provided to transaction functions. + + Inside a transaction function, the driver is responsible for managing + (committing / rolling back) the transaction. Therefore, + AsyncManagedTransactions don't offer such methods. + Otherwise, they behave like :class:`.AsyncTransaction`. + + * To commit the transaction, + return anything from the transaction function. + * To rollback the transaction, raise any exception. + + Note that transaction functions have to be idempontent (i.e., the result + of running the function once has to be the same as running it any number + of times). This is, because the driver will retry the transaction function + if the error is classified as retriable. + + .. versionadded:: 5.0 + + Prior, transaction functions used :class:`AsyncTransaction` objects, + but would cause hard to interpret errors when managed explicitly + (committed or rolled back by user code). + """ + pass diff --git a/neo4j/_sync/work/__init__.py b/neo4j/_sync/work/__init__.py index 3ceebdb13..9522d7b9a 100644 --- a/neo4j/_sync/work/__init__.py +++ b/neo4j/_sync/work/__init__.py @@ -19,14 +19,18 @@ from .session import ( Result, Session, - Transaction, Workspace, ) +from .transaction import ( + ManagedTransaction, + Transaction, +) __all__ = [ "Result", "Session", "Transaction", + "ManagedTransaction", "Workspace", ] diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index d372ba69b..780befdb6 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -16,7 +16,6 @@ # limitations under the License. -import asyncio from logging import getLogger from random import random from time import perf_counter @@ -44,7 +43,10 @@ ) from ...work import Query from .result import Result -from .transaction import Transaction +from .transaction import ( + ManagedTransaction, + Transaction, +) from .workspace import Workspace @@ -157,8 +159,9 @@ def close(self): self._state_failed = True if self._transaction: - if self._transaction.closed() is False: - self._transaction.rollback() # roll back the transaction if it is not closed + if self._transaction._closed() is False: + # roll back the transaction if it is not closed + self._transaction._rollback() self._transaction = None try: @@ -306,7 +309,7 @@ def last_bookmarks(self): if self._auto_result: self._auto_result.consume() - if self._transaction and self._transaction._closed: + if self._transaction and self._transaction._closed(): self._collect_bookmark(self._transaction._bookmark) self._transaction = None @@ -323,10 +326,10 @@ def _transaction_error_handler(self, _): self._transaction = None self._disconnect() - def _open_transaction(self, *, access_mode, metadata=None, + def _open_transaction(self, *, tx_cls, access_mode, metadata=None, timeout=None): self._connect(access_mode=access_mode) - self._transaction = Transaction( + self._transaction = tx_cls( self._connection, self._config.fetch_size, self._transaction_closed_handler, self._transaction_error_handler @@ -372,6 +375,7 @@ def begin_transaction(self, metadata=None, timeout=None): raise TransactionError("Explicit transaction already open") self._open_transaction( + tx_cls=Transaction, access_mode=self._config.default_access_mode, metadata=metadata, timeout=timeout ) @@ -396,6 +400,7 @@ def _run_transaction( while True: try: self._open_transaction( + tx_cls=ManagedTransaction, access_mode=access_mode, metadata=metadata, timeout=timeout ) @@ -403,10 +408,10 @@ def _run_transaction( try: result = transaction_function(tx, *args, **kwargs) except Exception: - tx.close() + tx._close() raise else: - tx.commit() + tx._commit() except IncompleteCommit: raise except (ServiceUnavailable, SessionExpired) as error: diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py index 916621893..f7d057474 100644 --- a/neo4j/_sync/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -16,6 +16,8 @@ # limitations under the License. +from functools import wraps + from ..._async_compat.util import Util from ...data import DataHydrator from ...exceptions import TransactionError @@ -24,17 +26,10 @@ from .result import Result -class Transaction: - """ Container for multiple Cypher queries to be executed within a single - context. asynctransactions can be used within a :py:const:`with` - block where the transaction is committed or rolled back on based on - whether an exception is raised:: - - with session.begin_transaction() as tx: - ... +__all__ = ("Transaction", "ManagedTransaction") - """ +class _AsyncTransactionBase: def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -42,22 +37,22 @@ def __init__(self, connection, fetch_size, on_closed, on_error): ) self._bookmark = None self._results = [] - self._closed = False + self._closed_flag = False self._last_error = None self._fetch_size = fetch_size self._on_closed = on_closed self._on_error = on_error - def __enter__(self): + def _enter(self): return self - def __exit__(self, exception_type, exception_value, traceback): - if self._closed: + def _exit(self, exception_type, exception_value, traceback): + if self._closed_flag: return success = not bool(exception_type) if success: - self.commit() - self.close() + self._commit() + self._close() def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout @@ -105,14 +100,16 @@ def run(self, query, parameters=None, **kwparameters): :param parameters: dictionary of parameters :type parameters: dict :param kwparameters: additional keyword parameters + :returns: a new :class:`neo4j.Result` object :rtype: :class:`neo4j.Result` + :raise TransactionError: if the transaction is already closed """ if isinstance(query, Query): raise ValueError("Query object is only supported for session.run") - if self._closed: + if self._closed_flag: raise TransactionError(self, "Transaction closed") if self._last_error: raise TransactionError(self, @@ -136,12 +133,12 @@ def run(self, query, parameters=None, **kwparameters): return result - def commit(self): + def _commit(self): """Mark this transaction as successful and close in order to trigger a COMMIT. :raise TransactionError: if the transaction is already closed """ - if self._closed: + if self._closed_flag: raise TransactionError(self, "Transaction closed") if self._last_error: raise TransactionError(self, @@ -156,17 +153,17 @@ def commit(self): self._connection.fetch_all() self._bookmark = metadata.get("bookmark") finally: - self._closed = True + self._closed_flag = True Util.callback(self._on_closed) return self._bookmark - def rollback(self): + def _rollback(self): """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK. :raise TransactionError: if the transaction is already closed """ - if self._closed: + if self._closed_flag: raise TransactionError(self, "Transaction closed") metadata = {} @@ -180,20 +177,82 @@ def rollback(self): self._connection.send_all() self._connection.fetch_all() finally: - self._closed = True + self._closed_flag = True Util.callback(self._on_closed) - def close(self): + def _close(self): """Close this transaction, triggering a ROLLBACK if not closed. """ - if self._closed: + if self._closed_flag: return - self.rollback() + self._rollback() - def closed(self): + def _closed(self): """Indicator to show whether the transaction has been closed. :return: :const:`True` if closed, :const:`False` otherwise. :rtype: bool """ - return self._closed + return self._closed_flag + + +class Transaction(_AsyncTransactionBase): + """ Container for multiple Cypher queries to be executed within a single + context. asynctransactions can be used within a :py:const:`with` + block where the transaction is committed or rolled back on based on + whether an exception is raised:: + + with session.begin_transaction() as tx: + ... + + """ + + @wraps(_AsyncTransactionBase._enter) + def __enter__(self): + return self._enter() + + @wraps(_AsyncTransactionBase._exit) + def __exit__(self, exception_type, exception_value, traceback): + self._exit(exception_type, exception_value, traceback) + + @wraps(_AsyncTransactionBase._commit) + def commit(self): + return self._commit() + + @wraps(_AsyncTransactionBase._rollback) + def rollback(self): + return self._rollback() + + @wraps(_AsyncTransactionBase._close) + def close(self): + return self._close() + + @wraps(_AsyncTransactionBase._closed) + def closed(self): + return self._closed() + + +class ManagedTransaction(_AsyncTransactionBase): + """Transaction object provided to transaction functions. + + Inside a transaction function, the driver is responsible for managing + (committing / rolling back) the transaction. Therefore, + ManagedTransactions don't offer such methods. + Otherwise, they behave like :class:`.Transaction`. + + * To commit the transaction, + return anything from the transaction function. + * To rollback the transaction, raise any exception. + + Note that transaction functions have to be idempontent (i.e., the result + of running the function once has to be the same as running it any number + of times). This is, because the driver will retry the transaction function + if the error is classified as retriable. + + .. versionadded:: 5.0 + + Prior, transaction functions used :class:`Transaction` objects, + but would cause hard to interpret errors when managed explicitly + (committed or rolled back by user code). + """ + pass diff --git a/setup.cfg b/setup.cfg index c63faed9e..421144bec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,3 +11,6 @@ multi_line_output=3 order_by_type=false remove_redundant_aliases=true use_parentheses=true + +[tool:pytest] +mock_use_standalone_module = true diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index 43134759e..06fb15acf 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -41,6 +41,7 @@ log, ) from ..backend import Request +from ..exceptions import MarkdAsDriverException TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parents[1] @@ -96,18 +97,29 @@ def _exc_stems_from_driver(exc): if DRIVER_PATH in p.parents: return True - async def _handle_driver_exc(self, exc): + async def write_driver_exc(self, exc): log.debug(traceback.format_exc()) - if isinstance(exc, Neo4jError): - msg = "" if exc.message is None else str(exc.message) - else: - msg = str(exc.args[0]) if exc.args else "" key = self.next_key() self.errors[key] = exc - payload = {"id": key, "errorType": str(type(exc)), "msg": msg} - if isinstance(exc, Neo4jError): - payload["code"] = exc.code + + payload = {"id": key, "msg": ""} + + if isinstance(exc, MarkdAsDriverException): + wrapped_exc = exc.wrapped_exc + payload["errorType"] = str(type(wrapped_exc)) + if wrapped_exc.args: + payload["msg"] = str(wrapped_exc.args[0]) + else: + payload["errorType"] = str(type(exc)) + if isinstance(exc, Neo4jError) and exc.message is not None: + payload["msg"] = str(exc.message) + elif exc.args: + payload["msg"] = str(exc.args[0]) + + if isinstance(exc, Neo4jError): + payload["code"] = exc.code + await self.send_response("DriverError", payload) async def _process(self, request): @@ -132,13 +144,13 @@ async def _process(self, request): " request: " + ", ".join(unsused_keys) ) except (Neo4jError, DriverError, UnsupportedServerProduct, - BoltError) as e: - await self._handle_driver_exc(e) + BoltError, MarkdAsDriverException) as e: + await self.write_driver_exc(e) except requests.FrontendError as e: await self.send_response("FrontendError", {"msg": str(e)}) except Exception as e: if self._exc_stems_from_driver(e): - await self._handle_driver_exc(e) + await self.write_driver_exc(e) else: tb = traceback.format_exc() log.error(tb) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index e79d566f0..43506bb9b 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -26,6 +26,7 @@ fromtestkit, totestkit, ) +from ..exceptions import MarkdAsDriverException class FrontendError(Exception): @@ -355,21 +356,36 @@ async def TransactionRun(backend, data): async def TransactionCommit(backend, data): key = data["txId"] tx = backend.transactions[key] - await tx.commit() + try: + commit = tx.commit + except AttributeError as e: + raise MarkdAsDriverException(e) + # raise DriverError("Type does not support commit %s" % type(tx)) + await commit() await backend.send_response("Transaction", {"id": key}) async def TransactionRollback(backend, data): key = data["txId"] tx = backend.transactions[key] - await tx.rollback() + try: + rollback = tx.rollback + except AttributeError as e: + raise MarkdAsDriverException(e) + # raise DriverError("Type does not support rollback %s" % type(tx)) + await rollback() await backend.send_response("Transaction", {"id": key}) async def TransactionClose(backend, data): key = data["txId"] tx = backend.transactions[key] - await tx.close() + try: + close = tx.close + except AttributeError as e: + raise MarkdAsDriverException(e) + # raise DriverError("Type does not support close %s" % type(tx)) + await close() await backend.send_response("Transaction", {"id": key}) diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 05c63cee1..d553089b1 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -41,6 +41,7 @@ log, ) from ..backend import Request +from ..exceptions import MarkdAsDriverException TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parents[1] @@ -96,18 +97,29 @@ def _exc_stems_from_driver(exc): if DRIVER_PATH in p.parents: return True - def _handle_driver_exc(self, exc): + def write_driver_exc(self, exc): log.debug(traceback.format_exc()) - if isinstance(exc, Neo4jError): - msg = "" if exc.message is None else str(exc.message) - else: - msg = str(exc.args[0]) if exc.args else "" key = self.next_key() self.errors[key] = exc - payload = {"id": key, "errorType": str(type(exc)), "msg": msg} - if isinstance(exc, Neo4jError): - payload["code"] = exc.code + + payload = {"id": key, "msg": ""} + + if isinstance(exc, MarkdAsDriverException): + wrapped_exc = exc.wrapped_exc + payload["errorType"] = str(type(wrapped_exc)) + if wrapped_exc.args: + payload["msg"] = str(wrapped_exc.args[0]) + else: + payload["errorType"] = str(type(exc)) + if isinstance(exc, Neo4jError) and exc.message is not None: + payload["msg"] = str(exc.message) + elif exc.args: + payload["msg"] = str(exc.args[0]) + + if isinstance(exc, Neo4jError): + payload["code"] = exc.code + self.send_response("DriverError", payload) def _process(self, request): @@ -132,13 +144,13 @@ def _process(self, request): " request: " + ", ".join(unsused_keys) ) except (Neo4jError, DriverError, UnsupportedServerProduct, - BoltError) as e: - self._handle_driver_exc(e) + BoltError, MarkdAsDriverException) as e: + self.write_driver_exc(e) except requests.FrontendError as e: self.send_response("FrontendError", {"msg": str(e)}) except Exception as e: if self._exc_stems_from_driver(e): - self._handle_driver_exc(e) + self.write_driver_exc(e) else: tb = traceback.format_exc() log.error(tb) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index b67e413c7..89e61d5bb 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -26,6 +26,7 @@ fromtestkit, totestkit, ) +from ..exceptions import MarkdAsDriverException class FrontendError(Exception): @@ -355,21 +356,36 @@ def TransactionRun(backend, data): def TransactionCommit(backend, data): key = data["txId"] tx = backend.transactions[key] - tx.commit() + try: + commit = tx.commit + except AttributeError as e: + raise MarkdAsDriverException(e) + # raise DriverError("Type does not support commit %s" % type(tx)) + commit() backend.send_response("Transaction", {"id": key}) def TransactionRollback(backend, data): key = data["txId"] tx = backend.transactions[key] - tx.rollback() + try: + rollback = tx.rollback + except AttributeError as e: + raise MarkdAsDriverException(e) + # raise DriverError("Type does not support rollback %s" % type(tx)) + rollback() backend.send_response("Transaction", {"id": key}) def TransactionClose(backend, data): key = data["txId"] tx = backend.transactions[key] - tx.close() + try: + close = tx.close + except AttributeError as e: + raise MarkdAsDriverException(e) + # raise DriverError("Type does not support close %s" % type(tx)) + close() backend.send_response("Transaction", {"id": key}) diff --git a/testkitbackend/exceptions.py b/testkitbackend/exceptions.py new file mode 100644 index 000000000..b16625960 --- /dev/null +++ b/testkitbackend/exceptions.py @@ -0,0 +1,25 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class MarkdAsDriverException(Exception): + """ + Wrap any error as DriverException + """ + def __init__(self, wrapped_exc): + super().__init__() + self.wrapped_exc = wrapped_exc diff --git a/tests/_async_compat/__init__.py b/tests/_async_compat/__init__.py index 577335802..183152153 100644 --- a/tests/_async_compat/__init__.py +++ b/tests/_async_compat/__init__.py @@ -16,37 +16,13 @@ # limitations under the License. -import sys - - -if sys.version_info >= (3, 8): - from unittest import mock - from unittest.mock import AsyncMockMixin -else: - import mock - from mock.mock import AsyncMockMixin - from .mark_decorator import ( mark_async_test, mark_sync_test, ) -AsyncMagicMock = mock.AsyncMock -MagicMock = mock.MagicMock -Mock = mock.Mock - - -class AsyncMock(AsyncMockMixin, Mock): - pass - - __all__ = [ "mark_async_test", "mark_sync_test", - "AsyncMagicMock", - "AsyncMock", - "MagicMock", - "Mock", - "mock", ] diff --git a/tests/integration/async_/test_custom_ssl_context.py b/tests/integration/async_/test_custom_ssl_context.py index c09fd099a..ed8a9fad3 100644 --- a/tests/integration/async_/test_custom_ssl_context.py +++ b/tests/integration/async_/test_custom_ssl_context.py @@ -21,21 +21,18 @@ from neo4j import AsyncGraphDatabase -from ..._async_compat import ( - mark_async_test, - mock, -) +from ..._async_compat import mark_async_test @mark_async_test -async def test_custom_ssl_context_is_wraps_connection(target, auth): +async def test_custom_ssl_context_is_wraps_connection(target, auth, mocker): class NoNeedToGoFurtherException(Exception): pass def wrap_fail(*_, **__): raise NoNeedToGoFurtherException() - fake_ssl_context = mock.create_autospec(SSLContext) + fake_ssl_context = mocker.create_autospec(SSLContext) fake_ssl_context.wrap_socket.side_effect = wrap_fail fake_ssl_context.wrap_bio.side_effect = wrap_fail driver = AsyncGraphDatabase.neo4j_driver( diff --git a/tests/integration/sync/test_custom_ssl_context.py b/tests/integration/sync/test_custom_ssl_context.py index 7b66db01c..0135d034a 100644 --- a/tests/integration/sync/test_custom_ssl_context.py +++ b/tests/integration/sync/test_custom_ssl_context.py @@ -21,21 +21,18 @@ from neo4j import GraphDatabase -from ..._async_compat import ( - mark_sync_test, - mock, -) +from ..._async_compat import mark_sync_test @mark_sync_test -def test_custom_ssl_context_is_wraps_connection(target, auth): +def test_custom_ssl_context_is_wraps_connection(target, auth, mocker): class NoNeedToGoFurtherException(Exception): pass def wrap_fail(*_, **__): raise NoNeedToGoFurtherException() - fake_ssl_context = mock.create_autospec(SSLContext) + fake_ssl_context = mocker.create_autospec(SSLContext) fake_ssl_context.wrap_socket.side_effect = wrap_fail fake_ssl_context.wrap_bio.side_effect = wrap_fail driver = GraphDatabase.neo4j_driver( diff --git a/tests/requirements.txt b/tests/requirements.txt index d3b598069..297be88b7 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -5,6 +5,5 @@ pytest-asyncio>=0.16.0 pytest-benchmark>=3.4.1 pytest-cov>=3.0.0 pytest-mock>=3.6.1 -# brings mock.AsyncMock to Python 3.7 (3.8+ ships with built-in support) -mock>=4.0.3; python_version < '3.8' +mock>=4.0.3 teamcity-messages>=1.29 diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 78b3fd0a7..d23465788 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -22,10 +22,7 @@ from neo4j.conf import PoolConfig from neo4j.exceptions import ConfigurationError -from ...._async_compat import ( - AsyncMagicMock, - mark_async_test, -) +from ...._async_compat import mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -99,11 +96,11 @@ async def test_simple_pull(fake_socket): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = AsyncMagicMock() + sockets.client.settimeout = mocker.AsyncMock() await sockets.server.send_message(0x70, { "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 90f79b9cf..17412f228 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -16,8 +16,6 @@ # limitations under the License. -from unittest.mock import MagicMock - import pytest from neo4j._async.io._bolt4 import AsyncBolt4x0 @@ -193,11 +191,11 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.MagicMock() await sockets.server.send_message(0x70, { "server": "Neo4j/4.0.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 4a5853352..47667a581 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -21,10 +21,7 @@ from neo4j._async.io._bolt4 import AsyncBolt4x1 from neo4j.conf import PoolConfig -from ...._async_compat import ( - AsyncMagicMock, - mark_async_test, -) +from ...._async_compat import mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -212,11 +209,11 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = AsyncMagicMock() + sockets.client.settimeout = mocker.AsyncMock() await sockets.server.send_message(0x70, { "server": "Neo4j/4.1.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 8ab5b1f03..9a5f3e6da 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -21,10 +21,7 @@ from neo4j._async.io._bolt4 import AsyncBolt4x2 from neo4j.conf import PoolConfig -from ...._async_compat import ( - AsyncMagicMock, - mark_async_test, -) +from ...._async_compat import mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -212,11 +209,11 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = AsyncMagicMock() + sockets.client.settimeout = mocker.AsyncMock() await sockets.server.send_message(0x70, { "server": "Neo4j/4.2.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 666cab9f9..3297e59d1 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -23,10 +23,7 @@ from neo4j._async.io._bolt4 import AsyncBolt4x3 from neo4j.conf import PoolConfig -from ...._async_compat import ( - AsyncMagicMock, - mark_async_test, -) +from ...._async_compat import mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -225,16 +222,17 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): )) @mark_async_test async def test_hint_recv_timeout_seconds( - fake_socket_pair, hints, valid, caplog + fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = AsyncMagicMock() + sockets.client.settimeout = mocker.AsyncMock() await sockets.server.send_message( 0x70, {"server": "Neo4j/4.3.0", "hints": hints} ) - connection = AsyncBolt4x3(address, sockets.client, - PoolConfig.max_connection_lifetime) + connection = AsyncBolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) with caplog.at_level(logging.INFO): await connection.hello() if valid: diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index d592e057b..b356657c3 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -17,17 +17,13 @@ import logging -from unittest.mock import MagicMock import pytest from neo4j._async.io._bolt4 import AsyncBolt4x4 from neo4j.conf import PoolConfig -from ...._async_compat import ( - AsyncMagicMock, - mark_async_test, -) +from ...._async_compat import mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -240,11 +236,11 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): )) @mark_async_test async def test_hint_recv_timeout_seconds( - fake_socket_pair, hints, valid, caplog + fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.MagicMock() await sockets.server.send_message( 0x70, {"server": "Neo4j/4.3.4", "hints": hints} ) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 666f93489..004b06180 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -30,11 +30,7 @@ ServiceUnavailable, ) -from ...._async_compat import ( - AsyncMock, - mark_async_test, - mock, -) +from ...._async_compat import mark_async_test class AsyncFakeSocket: @@ -216,18 +212,22 @@ async def test_pool_max_conn_pool_size(pool): @pytest.mark.parametrize("is_reset", (True, False)) @mark_async_test -async def test_pool_reset_when_released(is_reset, pool): +async def test_pool_reset_when_released(is_reset, pool, mocker): address = ("127.0.0.1", 7687) quick_connection_name = AsyncQuickConnection.__name__ - with mock.patch(f"{__name__}.{quick_connection_name}.is_reset", - new_callable=mock.PropertyMock) as is_reset_mock: - with mock.patch(f"{__name__}.{quick_connection_name}.reset", - new_callable=AsyncMock) as reset_mock: - is_reset_mock.return_value = is_reset - connection = await pool._acquire(address, 3, None) - assert isinstance(connection, AsyncQuickConnection) - assert is_reset_mock.call_count == 0 - assert reset_mock.call_count == 0 - await pool.release(connection) - assert is_reset_mock.call_count == 1 - assert reset_mock.call_count == int(not is_reset) + is_reset_mock = mocker.patch( + f"{__name__}.{quick_connection_name}.is_reset", + new_callable=mocker.PropertyMock + ) + reset_mock = mocker.patch( + f"{__name__}.{quick_connection_name}.reset", + new_callable=mocker.AsyncMock + ) + is_reset_mock.return_value = is_reset + connection = await pool._acquire(address, 3, None) + assert isinstance(connection, AsyncQuickConnection) + assert is_reset_mock.call_count == 0 + assert reset_mock.call_count == 0 + await pool.release(connection) + assert is_reset_mock.call_count == 1 + assert reset_mock.call_count == int(not is_reset) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index cbf05f309..455958379 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -16,8 +16,6 @@ # limitations under the License. -from unittest.mock import Mock - import pytest from neo4j import ( @@ -36,11 +34,8 @@ SessionExpired, ) -from ...._async_compat import ( - AsyncMock, - mark_async_test, -) -from ..work import AsyncFakeConnection +from ...._async_compat import mark_async_test +from ..work import async_fake_connection_generator ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") @@ -49,12 +44,12 @@ @pytest.fixture() -def opener(): +def opener(async_fake_connection_generator, mocker): async def open_(addr, timeout): - connection = AsyncFakeConnection() + connection = async_fake_connection_generator() connection.addr = addr connection.timeout = timeout - route_mock = AsyncMock() + route_mock = mocker.AsyncMock() route_mock.return_value = [{ "ttl": 1000, "servers": [ @@ -67,7 +62,7 @@ async def open_(addr, timeout): opener_.connections.append(connection) return connection - opener_ = AsyncMock() + opener_ = mocker.AsyncMock() opener_.connections = [] opener_.side_effect = open_ return opener_ diff --git a/tests/unit/async_/test_addressing.py b/tests/unit/async_/test_addressing.py index d0121214a..979de2178 100644 --- a/tests/unit/async_/test_addressing.py +++ b/tests/unit/async_/test_addressing.py @@ -20,7 +20,6 @@ AF_INET, AF_INET6, ) -import unittest.mock as mock import pytest @@ -34,13 +33,6 @@ from ..._async_compat import mark_async_test -mock_socket_ipv4 = mock.Mock() -mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port) - -mock_socket_ipv6 = mock.Mock() -mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id) - - @mark_async_test async def test_address_resolve(): address = Address(("127.0.0.1", 7687)) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 714a486f3..1dea1e1ad 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -33,11 +33,7 @@ ) from neo4j.exceptions import ConfigurationError -from ..._async_compat import ( - mark_async_test, - mock, -) -from .work import AsyncFakeConnection +from ..._async_compat import mark_async_test @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @@ -58,7 +54,8 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke await driver.close() -@pytest.mark.parametrize("protocol", ("neo4j://", "neo4j+s://", "neo4j+ssc://")) +@pytest.mark.parametrize("protocol", + ("neo4j://", "neo4j+s://", "neo4j+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @pytest.mark.parametrize("port", (":1234", "", ":7687")) @@ -175,28 +172,26 @@ async def test_driver_opens_write_session_by_default(uri, mocker): # to get hold of the actual home database (which won't work in this # unittest) async with driver.session(database="foobar") as session: - with mock.patch.object( - session._pool, "acquire", autospec=True - ) as acquire_mock: - with mock.patch.object( - AsyncTransaction, "_begin", autospec=True - ) as tx_begin_mock: - tx = await session.begin_transaction() - acquire_mock.assert_called_once_with( - access_mode=WRITE_ACCESS, - timeout=mocker.ANY, - database=mocker.ANY, - bookmarks=mocker.ANY - ) - tx_begin_mock.assert_called_once_with( - tx, - mocker.ANY, - mocker.ANY, - mocker.ANY, - WRITE_ACCESS, - mocker.ANY, - mocker.ANY - ) + acquire_mock = mocker.patch.object(session._pool, "acquire", + autospec=True) + tx_begin_mock = mocker.patch.object(AsyncTransaction, "_begin", + autospec=True) + tx = await session.begin_transaction() + acquire_mock.assert_called_once_with( + access_mode=WRITE_ACCESS, + timeout=mocker.ANY, + database=mocker.ANY, + bookmarks=mocker.ANY + ) + tx_begin_mock.assert_called_once_with( + tx, + mocker.ANY, + mocker.ANY, + mocker.ANY, + WRITE_ACCESS, + mocker.ANY, + mocker.ANY + ) await driver.close() @@ -206,12 +201,12 @@ async def test_driver_opens_write_session_by_default(uri, mocker): "neo4j://127.0.0.1:9000", )) @mark_async_test -async def test_verify_connectivity(uri): +async def test_verify_connectivity(uri, mocker): driver = AsyncGraphDatabase.driver(uri) + pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: - with mock.patch.object(driver, "_pool", autospec=True) as pool_mock: - ret = await driver.verify_connectivity() + ret = await driver.verify_connectivity() finally: await driver.close() @@ -231,12 +226,13 @@ async def test_verify_connectivity(uri): {"fetch_size": 69}, )) @mark_async_test -async def test_verify_connectivity_parameters_are_deprecated(uri, kwargs): +async def test_verify_connectivity_parameters_are_deprecated(uri, kwargs, + mocker): driver = AsyncGraphDatabase.driver(uri) + mocker.patch.object(driver, "_pool", autospec=True) try: - with mock.patch.object(driver, "_pool", autospec=True): - with pytest.warns(DeprecationWarning, match="configuration"): - await driver.verify_connectivity(**kwargs) + with pytest.warns(DeprecationWarning, match="configuration"): + await driver.verify_connectivity(**kwargs) finally: await driver.close() diff --git a/tests/unit/async_/work/__init__.py b/tests/unit/async_/work/__init__.py index 3bfbf0ed6..6b67846a2 100644 --- a/tests/unit/async_/work/__init__.py +++ b/tests/unit/async_/work/__init__.py @@ -18,5 +18,5 @@ from ._fake_connection import ( async_fake_connection, - AsyncFakeConnection, + async_fake_connection_generator, ) diff --git a/tests/unit/async_/work/_fake_connection.py b/tests/unit/async_/work/_fake_connection.py index fa62eb562..2ba962ad3 100644 --- a/tests/unit/async_/work/_fake_connection.py +++ b/tests/unit/async_/work/_fake_connection.py @@ -23,89 +23,90 @@ from neo4j import ServerInfo from neo4j._async.io import AsyncBolt -from ...._async_compat import ( - AsyncMock, - mock, - Mock, -) - - -class AsyncFakeConnection(mock.NonCallableMagicMock): - callbacks = [] - server_info = ServerInfo("127.0.0.1", (4, 3)) - - def __init__(self, *args, **kwargs): - kwargs["spec"] = AsyncBolt - super().__init__(*args, **kwargs) - self.attach_mock(Mock(return_value=True), "is_reset_mock") - self.attach_mock(Mock(return_value=False), "defunct") - self.attach_mock(Mock(return_value=False), "stale") - self.attach_mock(Mock(return_value=False), "closed") - self.attach_mock(Mock(), "unresolved_address") - - def close_side_effect(): - self.closed.return_value = True - - self.attach_mock(AsyncMock(side_effect=close_side_effect), - "close") - - @property - def is_reset(self): - if self.closed.return_value or self.defunct.return_value: - raise AssertionError( - "is_reset should not be called on a closed or defunct " - "connection." - ) - return self.is_reset_mock() - - async def fetch_message(self, *args, **kwargs): - if self.callbacks: - cb = self.callbacks.pop(0) - await cb() - return await super().__getattr__("fetch_message")(*args, **kwargs) - - async def fetch_all(self, *args, **kwargs): - while self.callbacks: - cb = self.callbacks.pop(0) - cb() - return await super().__getattr__("fetch_all")(*args, **kwargs) - - def __getattr__(self, name): - parent = super() - - def build_message_handler(name): - def func(*args, **kwargs): - async def callback(): - for cb_name, param_count in ( - ("on_success", 1), - ("on_summary", 0) - ): - cb = kwargs.get(cb_name, None) - if callable(cb): - try: - param_count = \ - len(inspect.signature(cb).parameters) - except ValueError: - # e.g. built-in method as cb - pass - if param_count == 1: - res = cb({}) - else: - res = cb() - try: - await res # maybe the callback is async - except TypeError: - pass # or maybe it wasn't ;) - self.callbacks.append(callback) - - return func - - method_mock = parent.__getattr__(name) - if name in ("run", "commit", "pull", "rollback", "discard"): - method_mock.side_effect = build_message_handler(name) - return method_mock + +@pytest.fixture +def async_fake_connection_generator(session_mocker): + mock = session_mocker.mock_module + + class AsyncFakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def __init__(self, *args, **kwargs): + kwargs["spec"] = AsyncBolt + super().__init__(*args, **kwargs) + self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(mock.Mock(), "unresolved_address") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.AsyncMock(side_effect=close_side_effect), + "close") + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError( + "is_reset should not be called on a closed or defunct " + "connection." + ) + return self.is_reset_mock() + + async def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + await cb() + return await super().__getattr__("fetch_message")(*args, **kwargs) + + async def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + await cb() + return await super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + async def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + res = cb({}) + else: + res = cb() + try: + await res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return AsyncFakeConnection @pytest.fixture -def async_fake_connection(): - return AsyncFakeConnection() +def async_fake_connection(async_fake_connection_generator): + return async_fake_connection_generator() diff --git a/tests/unit/async_/work/conftest.py b/tests/unit/async_/work/conftest.py new file mode 100644 index 000000000..6224f9c67 --- /dev/null +++ b/tests/unit/async_/work/conftest.py @@ -0,0 +1,4 @@ +from ._fake_connection import ( + async_fake_connection, + async_fake_connection_generator, +) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 37f8a22b2..b152de7f3 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -21,6 +21,7 @@ import pytest from neo4j import ( + AsyncManagedTransaction, AsyncSession, AsyncTransaction, Bookmarks, @@ -29,28 +30,24 @@ ) from neo4j._async.io._pool import AsyncIOPool -from ...._async_compat import ( - AsyncMock, - mark_async_test, - mock, -) -from ._fake_connection import AsyncFakeConnection +from ...._async_compat import mark_async_test +from ._fake_connection import async_fake_connection_generator @pytest.fixture() -def pool(): - pool = AsyncMock(spec=AsyncIOPool) - pool.acquire.side_effect = iter(AsyncFakeConnection, 0) +def pool(async_fake_connection_generator, mocker): + pool = mocker.AsyncMock(spec=AsyncIOPool) + pool.acquire.side_effect = iter(async_fake_connection_generator, 0) return pool @mark_async_test -async def test_session_context_calls_close(): +async def test_session_context_calls_close(mocker): s = AsyncSession(None, SessionConfig()) - with mock.patch.object(s, 'close', autospec=True) as mock_close: - async with s: - pass - mock_close.assert_called_once_with() + mock_close = mocker.patch.object(s, 'close', autospec=True) + async with s: + pass + mock_close.assert_called_once_with() @pytest.mark.parametrize("test_run_args", ( @@ -239,10 +236,10 @@ async def test_session_run_wrong_types(pool, query, error_type): @mark_async_test async def test_tx_function_argument_type(pool, tx_type): async def work(tx): - assert isinstance(tx, AsyncTransaction) + assert isinstance(tx, AsyncManagedTransaction) async with AsyncSession(pool, SessionConfig()) as session: - getattr(session, tx_type)(work) + await getattr(session, tx_type)(work) @pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) @@ -257,10 +254,10 @@ async def work(tx): async def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): @unit_of_work(**decorator_kwargs) async def work(tx): - assert isinstance(tx, AsyncTransaction) + assert isinstance(tx, AsyncManagedTransaction) async with AsyncSession(pool, SessionConfig()) as session: - getattr(session, tx_type)(work) + await getattr(session, tx_type)(work) @mark_async_test diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 6eafbb1a5..519d34233 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -22,11 +22,11 @@ import pytest from neo4j import ( + AsyncTransaction, Query, - Transaction, ) -from ._fake_connection import async_fake_connection +from ...._async_compat import mark_async_test @pytest.mark.parametrize(("explicit_commit", "close"), ( @@ -34,25 +34,27 @@ (True, False), (True, True), )) -def test_transaction_context_when_committing(mocker, async_fake_connection, - explicit_commit, close): - on_closed = MagicMock() - on_error = MagicMock() - tx = Transaction(async_fake_connection, 2, on_closed, on_error) - mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) - mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) - with tx as tx_: +@mark_async_test +async def test_transaction_context_when_committing( + mocker, async_fake_connection, explicit_commit, close +): + on_closed = mocker.AsyncMock() + on_error = mocker.AsyncMock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) + mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) + async with tx as tx_: assert mock_commit.call_count == 0 assert mock_rollback.call_count == 0 assert tx is tx_ if explicit_commit: - tx_.commit() - mock_commit.assert_called_once_with() - assert tx.closed() + await tx_.commit() + mock_commit.assert_awaited_once_with() + assert tx_.closed() if close: - tx_.close() + await tx_.close() assert tx_.closed() - mock_commit.assert_called_once_with() + mock_commit.assert_awaited_once_with() assert mock_rollback.call_count == 0 assert tx_.closed() @@ -62,47 +64,52 @@ def test_transaction_context_when_committing(mocker, async_fake_connection, (False, True), (True, True), )) -def test_transaction_context_with_explicit_rollback(mocker, async_fake_connection, - rollback, close): - on_closed = MagicMock() - on_error = MagicMock() - tx = Transaction(async_fake_connection, 2, on_closed, on_error) - mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) - mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) - with tx as tx_: +@mark_async_test +async def test_transaction_context_with_explicit_rollback( + mocker, async_fake_connection, rollback, close +): + on_closed = mocker.AsyncMock() + on_error = mocker.AsyncMock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) + mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) + async with tx as tx_: assert mock_commit.call_count == 0 assert mock_rollback.call_count == 0 assert tx is tx_ if rollback: - tx_.rollback() - mock_rollback.assert_called_once_with() + await tx_.rollback() + mock_rollback.assert_awaited_once_with() assert tx_.closed() if close: - tx_.close() - mock_rollback.assert_called_once_with() + await tx_.close() + mock_rollback.assert_awaited_once_with() assert tx_.closed() assert mock_commit.call_count == 0 - mock_rollback.assert_called_once_with() + mock_rollback.assert_awaited_once_with() assert tx_.closed() -def test_transaction_context_calls_rollback_on_error(mocker, async_fake_connection): +@mark_async_test +async def test_transaction_context_calls_rollback_on_error( + mocker, async_fake_connection +): class OopsError(RuntimeError): pass on_closed = MagicMock() on_error = MagicMock() - tx = Transaction(async_fake_connection, 2, on_closed, on_error) - mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) - mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) + mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with pytest.raises(OopsError): - with tx as tx_: + async with tx as tx_: assert mock_commit.call_count == 0 assert mock_rollback.call_count == 0 assert tx is tx_ raise OopsError assert mock_commit.call_count == 0 - mock_rollback.assert_called_once_with() + mock_rollback.assert_awaited_once_with() assert tx_.closed() @@ -112,74 +119,93 @@ class OopsError(RuntimeError): ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), ({"x": uuid4()}, TypeError), )) -def test_transaction_run_with_invalid_parameters(async_fake_connection, parameters, - error_type): +@mark_async_test +async def test_transaction_run_with_invalid_parameters( + async_fake_connection, parameters, error_type +): on_closed = MagicMock() on_error = MagicMock() - tx = Transaction(async_fake_connection, 2, on_closed, on_error) + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) with pytest.raises(error_type): - tx.run("RETURN $x", **parameters) + await tx.run("RETURN $x", **parameters) -def test_transaction_run_takes_no_query_object(async_fake_connection): +@mark_async_test +async def test_transaction_run_takes_no_query_object(async_fake_connection): on_closed = MagicMock() on_error = MagicMock() - tx = Transaction(async_fake_connection, 2, on_closed, on_error) + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) with pytest.raises(ValueError): - tx.run(Query("RETURN 1")) - - -def test_transaction_rollbacks_on_open_connections(async_fake_connection): - tx = Transaction(async_fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) - with tx as tx_: + await tx.run(Query("RETURN 1")) + + +@mark_async_test +async def test_transaction_rollbacks_on_open_connections( + async_fake_connection +): + tx = AsyncTransaction( + async_fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) + async with tx as tx_: async_fake_connection.is_reset_mock.return_value = False async_fake_connection.is_reset_mock.reset_mock() - tx_.rollback() + await tx_.rollback() async_fake_connection.is_reset_mock.assert_called_once() async_fake_connection.reset.assert_not_called() async_fake_connection.rollback.assert_called_once() -def test_transaction_no_rollback_on_reset_connections(async_fake_connection): - tx = Transaction(async_fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) - with tx as tx_: +@mark_async_test +async def test_transaction_no_rollback_on_reset_connections( + async_fake_connection +): + tx = AsyncTransaction( + async_fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) + async with tx as tx_: async_fake_connection.is_reset_mock.return_value = True async_fake_connection.is_reset_mock.reset_mock() - tx_.rollback() + await tx_.rollback() async_fake_connection.is_reset_mock.assert_called_once() - async_fake_connection.reset.asset_not_called() - async_fake_connection.rollback.asset_not_called() - - -def test_transaction_no_rollback_on_closed_connections(async_fake_connection): - tx = Transaction(async_fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) - with tx as tx_: + async_fake_connection.reset.assert_not_called() + async_fake_connection.rollback.assert_not_called() + + +@mark_async_test +async def test_transaction_no_rollback_on_closed_connections( + async_fake_connection +): + tx = AsyncTransaction( + async_fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) + async with tx as tx_: async_fake_connection.closed.return_value = True async_fake_connection.closed.reset_mock() async_fake_connection.is_reset_mock.reset_mock() - tx_.rollback() + await tx_.rollback() async_fake_connection.closed.assert_called_once() - async_fake_connection.is_reset_mock.asset_not_called() - async_fake_connection.reset.asset_not_called() - async_fake_connection.rollback.asset_not_called() - - -def test_transaction_no_rollback_on_defunct_connections(async_fake_connection): - tx = Transaction(async_fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) - with tx as tx_: + async_fake_connection.is_reset_mock.assert_not_called() + async_fake_connection.reset.assert_not_called() + async_fake_connection.rollback.assert_not_called() + + +@mark_async_test +async def test_transaction_no_rollback_on_defunct_connections( + async_fake_connection +): + tx = AsyncTransaction( + async_fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) + async with tx as tx_: async_fake_connection.defunct.return_value = True async_fake_connection.defunct.reset_mock() async_fake_connection.is_reset_mock.reset_mock() - tx_.rollback() + await tx_.rollback() async_fake_connection.defunct.assert_called_once() - async_fake_connection.is_reset_mock.asset_not_called() - async_fake_connection.reset.asset_not_called() - async_fake_connection.rollback.asset_not_called() + async_fake_connection.is_reset_mock.assert_not_called() + async_fake_connection.reset.assert_not_called() + async_fake_connection.rollback.assert_not_called() diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 7f7577259..34f5f8d78 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -29,24 +29,10 @@ Thread, ) import time -from unittest import ( - mock, - TestCase, -) import pytest -from neo4j import ( - Config, - PoolConfig, - WorkspaceConfig, -) -from neo4j._async.io import AsyncBolt -from neo4j._async.io._pool import AsyncIOPool -from neo4j.exceptions import ( - ClientError, - ServiceUnavailable, -) +from neo4j.exceptions import ClientError from ...async_.io.test_direct import AsyncFakeBoltPool from ...sync.io.test_direct import FakeBoltPool diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 7a1060f50..2b191eef0 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -22,10 +22,7 @@ from neo4j.conf import PoolConfig from neo4j.exceptions import ConfigurationError -from ...._async_compat import ( - MagicMock, - mark_sync_test, -) +from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -99,11 +96,11 @@ def test_simple_pull(fake_socket): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.Mock() sockets.server.send_message(0x70, { "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index cfa60f723..f1a6cf0f5 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -16,8 +16,6 @@ # limitations under the License. -from unittest.mock import MagicMock - import pytest from neo4j._sync.io._bolt4 import Bolt4x0 @@ -193,11 +191,11 @@ def test_n_and_qid_extras_in_pull(fake_socket): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.MagicMock() sockets.server.send_message(0x70, { "server": "Neo4j/4.0.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 4cdf0edbd..473e9926f 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -21,10 +21,7 @@ from neo4j._sync.io._bolt4 import Bolt4x1 from neo4j.conf import PoolConfig -from ...._async_compat import ( - MagicMock, - mark_sync_test, -) +from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -212,11 +209,11 @@ def test_hello_passes_routing_metadata(fake_socket_pair): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.Mock() sockets.server.send_message(0x70, { "server": "Neo4j/4.1.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 4c09a9def..1eef06d84 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -21,10 +21,7 @@ from neo4j._sync.io._bolt4 import Bolt4x2 from neo4j.conf import PoolConfig -from ...._async_compat import ( - MagicMock, - mark_sync_test, -) +from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -212,11 +209,11 @@ def test_hello_passes_routing_metadata(fake_socket_pair): @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( - fake_socket_pair, recv_timeout + fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.Mock() sockets.server.send_message(0x70, { "server": "Neo4j/4.2.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index d83765f63..6eec47107 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -23,10 +23,7 @@ from neo4j._sync.io._bolt4 import Bolt4x3 from neo4j.conf import PoolConfig -from ...._async_compat import ( - MagicMock, - mark_sync_test, -) +from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -225,16 +222,17 @@ def test_hello_passes_routing_metadata(fake_socket_pair): )) @mark_sync_test def test_hint_recv_timeout_seconds( - fake_socket_pair, hints, valid, caplog + fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.Mock() sockets.server.send_message( 0x70, {"server": "Neo4j/4.3.0", "hints": hints} ) - connection = Bolt4x3(address, sockets.client, - PoolConfig.max_connection_lifetime) + connection = Bolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) with caplog.at_level(logging.INFO): connection.hello() if valid: diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 578a960bf..ea0a605ab 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -17,17 +17,13 @@ import logging -from unittest.mock import MagicMock import pytest from neo4j._sync.io._bolt4 import Bolt4x4 from neo4j.conf import PoolConfig -from ...._async_compat import ( - MagicMock, - mark_sync_test, -) +from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) @@ -240,11 +236,11 @@ def test_hello_passes_routing_metadata(fake_socket_pair): )) @mark_sync_test def test_hint_recv_timeout_seconds( - fake_socket_pair, hints, valid, caplog + fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.client.settimeout = MagicMock() + sockets.client.settimeout = mocker.MagicMock() sockets.server.send_message( 0x70, {"server": "Neo4j/4.3.4", "hints": hints} ) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 0d8d6ad88..eeb34a983 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -30,11 +30,7 @@ ServiceUnavailable, ) -from ...._async_compat import ( - mark_sync_test, - Mock, - mock, -) +from ...._async_compat import mark_sync_test class FakeSocket: @@ -216,18 +212,22 @@ def test_pool_max_conn_pool_size(pool): @pytest.mark.parametrize("is_reset", (True, False)) @mark_sync_test -def test_pool_reset_when_released(is_reset, pool): +def test_pool_reset_when_released(is_reset, pool, mocker): address = ("127.0.0.1", 7687) quick_connection_name = QuickConnection.__name__ - with mock.patch(f"{__name__}.{quick_connection_name}.is_reset", - new_callable=mock.PropertyMock) as is_reset_mock: - with mock.patch(f"{__name__}.{quick_connection_name}.reset", - new_callable=Mock) as reset_mock: - is_reset_mock.return_value = is_reset - connection = pool._acquire(address, 3, None) - assert isinstance(connection, QuickConnection) - assert is_reset_mock.call_count == 0 - assert reset_mock.call_count == 0 - pool.release(connection) - assert is_reset_mock.call_count == 1 - assert reset_mock.call_count == int(not is_reset) + is_reset_mock = mocker.patch( + f"{__name__}.{quick_connection_name}.is_reset", + new_callable=mocker.PropertyMock + ) + reset_mock = mocker.patch( + f"{__name__}.{quick_connection_name}.reset", + new_callable=mocker.Mock + ) + is_reset_mock.return_value = is_reset + connection = pool._acquire(address, 3, None) + assert isinstance(connection, QuickConnection) + assert is_reset_mock.call_count == 0 + assert reset_mock.call_count == 0 + pool.release(connection) + assert is_reset_mock.call_count == 1 + assert reset_mock.call_count == int(not is_reset) diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 4d7b17380..0b00a6aa0 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -16,8 +16,6 @@ # limitations under the License. -from unittest.mock import Mock - import pytest from neo4j import ( @@ -36,11 +34,8 @@ SessionExpired, ) -from ...._async_compat import ( - mark_sync_test, - Mock, -) -from ..work import FakeConnection +from ...._async_compat import mark_sync_test +from ..work import fake_connection_generator ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") @@ -49,12 +44,12 @@ @pytest.fixture() -def opener(): +def opener(fake_connection_generator, mocker): def open_(addr, timeout): - connection = FakeConnection() + connection = fake_connection_generator() connection.addr = addr connection.timeout = timeout - route_mock = Mock() + route_mock = mocker.Mock() route_mock.return_value = [{ "ttl": 1000, "servers": [ @@ -67,7 +62,7 @@ def open_(addr, timeout): opener_.connections.append(connection) return connection - opener_ = Mock() + opener_ = mocker.Mock() opener_.connections = [] opener_.side_effect = open_ return opener_ diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py index 1ca0106da..406e15a27 100644 --- a/tests/unit/sync/test_addressing.py +++ b/tests/unit/sync/test_addressing.py @@ -20,7 +20,6 @@ AF_INET, AF_INET6, ) -import unittest.mock as mock import pytest @@ -34,13 +33,6 @@ from ..._async_compat import mark_sync_test -mock_socket_ipv4 = mock.Mock() -mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port) - -mock_socket_ipv6 = mock.Mock() -mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id) - - @mark_sync_test def test_address_resolve(): address = Address(("127.0.0.1", 7687)) diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 8b45a7f48..ecdede08e 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -33,11 +33,7 @@ ) from neo4j.exceptions import ConfigurationError -from ..._async_compat import ( - mark_sync_test, - mock, -) -from .work import FakeConnection +from ..._async_compat import mark_sync_test @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @@ -58,7 +54,8 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token): driver.close() -@pytest.mark.parametrize("protocol", ("neo4j://", "neo4j+s://", "neo4j+ssc://")) +@pytest.mark.parametrize("protocol", + ("neo4j://", "neo4j+s://", "neo4j+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @pytest.mark.parametrize("port", (":1234", "", ":7687")) @@ -175,28 +172,26 @@ def test_driver_opens_write_session_by_default(uri, mocker): # to get hold of the actual home database (which won't work in this # unittest) with driver.session(database="foobar") as session: - with mock.patch.object( - session._pool, "acquire", autospec=True - ) as acquire_mock: - with mock.patch.object( - Transaction, "_begin", autospec=True - ) as tx_begin_mock: - tx = session.begin_transaction() - acquire_mock.assert_called_once_with( - access_mode=WRITE_ACCESS, - timeout=mocker.ANY, - database=mocker.ANY, - bookmarks=mocker.ANY - ) - tx_begin_mock.assert_called_once_with( - tx, - mocker.ANY, - mocker.ANY, - mocker.ANY, - WRITE_ACCESS, - mocker.ANY, - mocker.ANY - ) + acquire_mock = mocker.patch.object(session._pool, "acquire", + autospec=True) + tx_begin_mock = mocker.patch.object(Transaction, "_begin", + autospec=True) + tx = session.begin_transaction() + acquire_mock.assert_called_once_with( + access_mode=WRITE_ACCESS, + timeout=mocker.ANY, + database=mocker.ANY, + bookmarks=mocker.ANY + ) + tx_begin_mock.assert_called_once_with( + tx, + mocker.ANY, + mocker.ANY, + mocker.ANY, + WRITE_ACCESS, + mocker.ANY, + mocker.ANY + ) driver.close() @@ -206,12 +201,12 @@ def test_driver_opens_write_session_by_default(uri, mocker): "neo4j://127.0.0.1:9000", )) @mark_sync_test -def test_verify_connectivity(uri): +def test_verify_connectivity(uri, mocker): driver = GraphDatabase.driver(uri) + pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: - with mock.patch.object(driver, "_pool", autospec=True) as pool_mock: - ret = driver.verify_connectivity() + ret = driver.verify_connectivity() finally: driver.close() @@ -231,12 +226,13 @@ def test_verify_connectivity(uri): {"fetch_size": 69}, )) @mark_sync_test -def test_verify_connectivity_parameters_are_deprecated(uri, kwargs): +def test_verify_connectivity_parameters_are_deprecated(uri, kwargs, + mocker): driver = GraphDatabase.driver(uri) + mocker.patch.object(driver, "_pool", autospec=True) try: - with mock.patch.object(driver, "_pool", autospec=True): - with pytest.warns(DeprecationWarning, match="configuration"): - driver.verify_connectivity(**kwargs) + with pytest.warns(DeprecationWarning, match="configuration"): + driver.verify_connectivity(**kwargs) finally: driver.close() diff --git a/tests/unit/sync/work/__init__.py b/tests/unit/sync/work/__init__.py index 2613b53d2..06447a991 100644 --- a/tests/unit/sync/work/__init__.py +++ b/tests/unit/sync/work/__init__.py @@ -18,5 +18,5 @@ from ._fake_connection import ( fake_connection, - FakeConnection, + fake_connection_generator, ) diff --git a/tests/unit/sync/work/_fake_connection.py b/tests/unit/sync/work/_fake_connection.py index 72977898d..557c333b4 100644 --- a/tests/unit/sync/work/_fake_connection.py +++ b/tests/unit/sync/work/_fake_connection.py @@ -23,88 +23,90 @@ from neo4j import ServerInfo from neo4j._sync.io import Bolt -from ...._async_compat import ( - Mock, - mock, -) - - -class FakeConnection(mock.NonCallableMagicMock): - callbacks = [] - server_info = ServerInfo("127.0.0.1", (4, 3)) - - def __init__(self, *args, **kwargs): - kwargs["spec"] = Bolt - super().__init__(*args, **kwargs) - self.attach_mock(Mock(return_value=True), "is_reset_mock") - self.attach_mock(Mock(return_value=False), "defunct") - self.attach_mock(Mock(return_value=False), "stale") - self.attach_mock(Mock(return_value=False), "closed") - self.attach_mock(Mock(), "unresolved_address") - - def close_side_effect(): - self.closed.return_value = True - - self.attach_mock(Mock(side_effect=close_side_effect), - "close") - - @property - def is_reset(self): - if self.closed.return_value or self.defunct.return_value: - raise AssertionError( - "is_reset should not be called on a closed or defunct " - "connection." - ) - return self.is_reset_mock() - - def fetch_message(self, *args, **kwargs): - if self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_message")(*args, **kwargs) - - def fetch_all(self, *args, **kwargs): - while self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_all")(*args, **kwargs) - - def __getattr__(self, name): - parent = super() - - def build_message_handler(name): - def func(*args, **kwargs): - def callback(): - for cb_name, param_count in ( - ("on_success", 1), - ("on_summary", 0) - ): - cb = kwargs.get(cb_name, None) - if callable(cb): - try: - param_count = \ - len(inspect.signature(cb).parameters) - except ValueError: - # e.g. built-in method as cb - pass - if param_count == 1: - res = cb({}) - else: - res = cb() - try: - res # maybe the callback is async - except TypeError: - pass # or maybe it wasn't ;) - self.callbacks.append(callback) - - return func - - method_mock = parent.__getattr__(name) - if name in ("run", "commit", "pull", "rollback", "discard"): - method_mock.side_effect = build_message_handler(name) - return method_mock + +@pytest.fixture +def fake_connection_generator(session_mocker): + mock = session_mocker.mock_module + + class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def __init__(self, *args, **kwargs): + kwargs["spec"] = Bolt + super().__init__(*args, **kwargs) + self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(mock.Mock(), "unresolved_address") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.Mock(side_effect=close_side_effect), + "close") + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError( + "is_reset should not be called on a closed or defunct " + "connection." + ) + return self.is_reset_mock() + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + res = cb({}) + else: + res = cb() + try: + res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return FakeConnection @pytest.fixture -def fake_connection(): - return FakeConnection() +def fake_connection(fake_connection_generator): + return fake_connection_generator() diff --git a/tests/unit/sync/work/conftest.py b/tests/unit/sync/work/conftest.py new file mode 100644 index 000000000..6302829c2 --- /dev/null +++ b/tests/unit/sync/work/conftest.py @@ -0,0 +1,4 @@ +from ._fake_connection import ( + fake_connection, + fake_connection_generator, +) diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 727f3330f..f2a8aa595 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -22,6 +22,7 @@ from neo4j import ( Bookmarks, + ManagedTransaction, Session, SessionConfig, Transaction, @@ -29,28 +30,24 @@ ) from neo4j._sync.io._pool import IOPool -from ...._async_compat import ( - mark_sync_test, - Mock, - mock, -) -from ._fake_connection import FakeConnection +from ...._async_compat import mark_sync_test +from ._fake_connection import fake_connection_generator @pytest.fixture() -def pool(): - pool = Mock(spec=IOPool) - pool.acquire.side_effect = iter(FakeConnection, 0) +def pool(fake_connection_generator, mocker): + pool = mocker.Mock(spec=IOPool) + pool.acquire.side_effect = iter(fake_connection_generator, 0) return pool @mark_sync_test -def test_session_context_calls_close(): +def test_session_context_calls_close(mocker): s = Session(None, SessionConfig()) - with mock.patch.object(s, 'close', autospec=True) as mock_close: - with s: - pass - mock_close.assert_called_once_with() + mock_close = mocker.patch.object(s, 'close', autospec=True) + with s: + pass + mock_close.assert_called_once_with() @pytest.mark.parametrize("test_run_args", ( @@ -239,7 +236,7 @@ def test_session_run_wrong_types(pool, query, error_type): @mark_sync_test def test_tx_function_argument_type(pool, tx_type): def work(tx): - assert isinstance(tx, Transaction) + assert isinstance(tx, ManagedTransaction) with Session(pool, SessionConfig()) as session: getattr(session, tx_type)(work) @@ -257,7 +254,7 @@ def work(tx): def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): @unit_of_work(**decorator_kwargs) def work(tx): - assert isinstance(tx, Transaction) + assert isinstance(tx, ManagedTransaction) with Session(pool, SessionConfig()) as session: getattr(session, tx_type)(work) diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index b5b40283c..56aa2bc64 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -26,7 +26,7 @@ Transaction, ) -from ._fake_connection import fake_connection +from ...._async_compat import mark_sync_test @pytest.mark.parametrize(("explicit_commit", "close"), ( @@ -34,13 +34,15 @@ (True, False), (True, True), )) -def test_transaction_context_when_committing(mocker, fake_connection, - explicit_commit, close): - on_closed = MagicMock() - on_error = MagicMock() +@mark_sync_test +def test_transaction_context_when_committing( + mocker, fake_connection, explicit_commit, close +): + on_closed = mocker.Mock() + on_error = mocker.Mock() tx = Transaction(fake_connection, 2, on_closed, on_error) - mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) - mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) + mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with tx as tx_: assert mock_commit.call_count == 0 assert mock_rollback.call_count == 0 @@ -48,7 +50,7 @@ def test_transaction_context_when_committing(mocker, fake_connection, if explicit_commit: tx_.commit() mock_commit.assert_called_once_with() - assert tx.closed() + assert tx_.closed() if close: tx_.close() assert tx_.closed() @@ -62,13 +64,15 @@ def test_transaction_context_when_committing(mocker, fake_connection, (False, True), (True, True), )) -def test_transaction_context_with_explicit_rollback(mocker, fake_connection, - rollback, close): - on_closed = MagicMock() - on_error = MagicMock() +@mark_sync_test +def test_transaction_context_with_explicit_rollback( + mocker, fake_connection, rollback, close +): + on_closed = mocker.Mock() + on_error = mocker.Mock() tx = Transaction(fake_connection, 2, on_closed, on_error) - mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) - mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) + mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with tx as tx_: assert mock_commit.call_count == 0 assert mock_rollback.call_count == 0 @@ -86,15 +90,18 @@ def test_transaction_context_with_explicit_rollback(mocker, fake_connection, assert tx_.closed() -def test_transaction_context_calls_rollback_on_error(mocker, fake_connection): +@mark_sync_test +def test_transaction_context_calls_rollback_on_error( + mocker, fake_connection +): class OopsError(RuntimeError): pass on_closed = MagicMock() on_error = MagicMock() tx = Transaction(fake_connection, 2, on_closed, on_error) - mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) - mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) + mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with pytest.raises(OopsError): with tx as tx_: assert mock_commit.call_count == 0 @@ -112,8 +119,10 @@ class OopsError(RuntimeError): ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), ({"x": uuid4()}, TypeError), )) -def test_transaction_run_with_invalid_parameters(fake_connection, parameters, - error_type): +@mark_sync_test +def test_transaction_run_with_invalid_parameters( + fake_connection, parameters, error_type +): on_closed = MagicMock() on_error = MagicMock() tx = Transaction(fake_connection, 2, on_closed, on_error) @@ -121,6 +130,7 @@ def test_transaction_run_with_invalid_parameters(fake_connection, parameters, tx.run("RETURN $x", **parameters) +@mark_sync_test def test_transaction_run_takes_no_query_object(fake_connection): on_closed = MagicMock() on_error = MagicMock() @@ -129,10 +139,14 @@ def test_transaction_run_takes_no_query_object(fake_connection): tx.run(Query("RETURN 1")) -def test_transaction_rollbacks_on_open_connections(fake_connection): - tx = Transaction(fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) +@mark_sync_test +def test_transaction_rollbacks_on_open_connections( + fake_connection +): + tx = Transaction( + fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) with tx as tx_: fake_connection.is_reset_mock.return_value = False fake_connection.is_reset_mock.reset_mock() @@ -142,44 +156,56 @@ def test_transaction_rollbacks_on_open_connections(fake_connection): fake_connection.rollback.assert_called_once() -def test_transaction_no_rollback_on_reset_connections(fake_connection): - tx = Transaction(fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) +@mark_sync_test +def test_transaction_no_rollback_on_reset_connections( + fake_connection +): + tx = Transaction( + fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) with tx as tx_: fake_connection.is_reset_mock.return_value = True fake_connection.is_reset_mock.reset_mock() tx_.rollback() fake_connection.is_reset_mock.assert_called_once() - fake_connection.reset.asset_not_called() - fake_connection.rollback.asset_not_called() + fake_connection.reset.assert_not_called() + fake_connection.rollback.assert_not_called() -def test_transaction_no_rollback_on_closed_connections(fake_connection): - tx = Transaction(fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) +@mark_sync_test +def test_transaction_no_rollback_on_closed_connections( + fake_connection +): + tx = Transaction( + fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) with tx as tx_: fake_connection.closed.return_value = True fake_connection.closed.reset_mock() fake_connection.is_reset_mock.reset_mock() tx_.rollback() fake_connection.closed.assert_called_once() - fake_connection.is_reset_mock.asset_not_called() - fake_connection.reset.asset_not_called() - fake_connection.rollback.asset_not_called() + fake_connection.is_reset_mock.assert_not_called() + fake_connection.reset.assert_not_called() + fake_connection.rollback.assert_not_called() -def test_transaction_no_rollback_on_defunct_connections(fake_connection): - tx = Transaction(fake_connection, 2, - lambda *args, **kwargs: None, - lambda *args, **kwargs: None) +@mark_sync_test +def test_transaction_no_rollback_on_defunct_connections( + fake_connection +): + tx = Transaction( + fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None + ) with tx as tx_: fake_connection.defunct.return_value = True fake_connection.defunct.reset_mock() fake_connection.is_reset_mock.reset_mock() tx_.rollback() fake_connection.defunct.assert_called_once() - fake_connection.is_reset_mock.asset_not_called() - fake_connection.reset.asset_not_called() - fake_connection.rollback.asset_not_called() + fake_connection.is_reset_mock.assert_not_called() + fake_connection.reset.assert_not_called() + fake_connection.rollback.assert_not_called()