From e11879ef4066ebe7735c2dd6bb1ca49065e3d415 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 11 Aug 2023 10:39:16 +0200 Subject: [PATCH] Fix pool closing connections too aggressively (#955) Whenever a new routing table was fetched, the pool would close all connections to servers that were not part of the routing table. However, it might well be, that a missing server is present still in the routing table for another database. Hence, the pool now checks the routing tables for all databases before deciding which connections are no longer needed. --- src/neo4j/_async/io/_pool.py | 10 +- src/neo4j/_sync/io/_pool.py | 10 +- tests/unit/async_/io/test_neo4j_pool.py | 157 +++++++++++++++++++----- tests/unit/sync/io/test_neo4j_pool.py | 157 +++++++++++++++++++----- 4 files changed, 262 insertions(+), 72 deletions(-) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index e854fe987..619a662f1 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -813,8 +813,13 @@ async def update_routing_table( raise ServiceUnavailable("Unable to retrieve routing information") async def update_connection_pool(self, *, database): - routing_table = await self.get_or_create_routing_table(database) - servers = routing_table.servers() + async with self.refresh_lock: + routing_tables = [await self.get_or_create_routing_table(database)] + for db in self.routing_tables.keys(): + if db == database: + continue + routing_tables.append(self.routing_tables[db]) + servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address._unresolved not in servers: await super(AsyncNeo4jPool, self).deactivate(address) @@ -960,6 +965,7 @@ async def deactivate(self, address): async def on_write_failure(self, address): """ Remove a writer address from the routing table, if present. """ + # FIXME: only need to remove the writer for a specific database log.debug("[#0000] _: removing writer %r", address) async with self.refresh_lock: for database in self.routing_tables.keys(): diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index b8adff017..73f77944a 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -810,8 +810,13 @@ def update_routing_table( raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): - routing_table = self.get_or_create_routing_table(database) - servers = routing_table.servers() + with self.refresh_lock: + routing_tables = [self.get_or_create_routing_table(database)] + for db in self.routing_tables.keys(): + if db == database: + continue + routing_tables.append(self.routing_tables[db]) + servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address._unresolved not in servers: super(Neo4jPool, self).deactivate(address) @@ -957,6 +962,7 @@ def deactivate(self, address): def on_write_failure(self, address): """ Remove a writer address from the routing table, if present. """ + # FIXME: only need to remove the writer for a specific database log.debug("[#0000] _: removing writer %r", address) with self.refresh_lock: for database in self.routing_tables.keys(): diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index ca27d92f3..a8249eed6 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -17,6 +17,7 @@ import inspect +from collections import defaultdict import pytest @@ -50,17 +51,23 @@ ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") -READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") -WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") +READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") +READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host") +READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host") +WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") @pytest.fixture -def routing_failure_opener(async_fake_connection_generator, mocker): - def make_opener(failures=None): +def custom_routing_opener(async_fake_connection_generator, mocker): + def make_opener(failures=None, get_readers=None): def routing_side_effect(*args, **kwargs): nonlocal failures res = next(failures, None) if res is None: + if get_readers is not None: + readers = get_readers(kwargs.get("database")) + else: + readers = [str(READER1_ADDRESS)] return [{ "ttl": 1000, "servers": [ @@ -68,8 +75,8 @@ def routing_side_effect(*args, **kwargs): str(ROUTER2_ADDRESS), str(ROUTER3_ADDRESS)], "role": "ROUTE"}, - {"addresses": [str(READER_ADDRESS)], "role": "READ"}, - {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"}, ], }] raise res @@ -96,8 +103,8 @@ async def open_(addr, auth, timeout): @pytest.fixture -def opener(routing_failure_opener): - return routing_failure_opener() +def opener(custom_routing_opener): + return custom_routing_opener() def _pool_config(): @@ -177,9 +184,9 @@ async def test_chooses_right_connection_type(opener, type_): ) await pool.release(cx1) if type_ == "r": - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS else: - assert cx1.unresolved_address == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER1_ADDRESS @mark_async_test @@ -298,9 +305,9 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): pool = _simple_pool(opener) - cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS cx1.reset.assert_not_called() @@ -311,11 +318,11 @@ async def test_acquire_performs_liveness_check_on_existing_connection( ): pool = _simple_pool(opener) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -326,7 +333,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection( cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) @@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = _simple_pool(opener) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) assert cx1 is not cx2 assert cx1.unresolved_address == cx2.unresolved_address @@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = _simple_pool(opener) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) - cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.unresolved_address == READER_ADDRESS - assert cx2.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS + assert cx2.unresolved_address == READER1_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs): cx2.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx3 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + cx3 = await pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) assert cx3 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) @@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx): async def close_side_effect(): cx.closed.return_value = True cx.defunct.return_value = True - await pool.deactivate(READER_ADDRESS) + await pool.deactivate(READER1_ADDRESS) cx.attach_mock(mocker.AsyncMock(side_effect=close_side_effect), "close") @@ -470,9 +477,9 @@ async def test__acquire_new_later_with_room(opener): pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 1 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 1 assert callable(creator) if AsyncUtil.is_async_code: assert inspect.iscoroutinefunction(creator) @@ -487,9 +494,9 @@ async def test__acquire_new_later_without_room(opener): ) _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 0 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 0 assert creator is None @@ -519,8 +526,8 @@ async def test_passes_pool_config_to_connection(mocker): "Neo.ClientError.Security.AuthorizationExpired"), )) @mark_async_test -async def test_discovery_is_retried(routing_failure_opener, error): - opener = routing_failure_opener([ +async def test_discovery_is_retried(custom_routing_opener, error): + opener = custom_routing_opener([ None, # first call to router for seeding the RT with more routers error, # will be retried ]) @@ -563,8 +570,8 @@ async def test_discovery_is_retried(routing_failure_opener, error): ) )) @mark_async_test -async def test_fast_failing_discovery(routing_failure_opener, error): - opener = routing_failure_opener([ +async def test_fast_failing_discovery(custom_routing_opener, error): + opener = custom_routing_opener([ None, # first call to router for seeding the RT with more routers error, # will be retried ]) @@ -648,3 +655,85 @@ async def test_connection_error_callback( cx.mark_unauthenticated.assert_not_called() for cx in cxs_write: cx.mark_unauthenticated.assert_not_called() + + +@mark_async_test +async def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): + readers = {"db1": [str(READER1_ADDRESS)]} + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = AsyncNeo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + assert cx1.unresolved_address == READER1_ADDRESS + await pool.release(cx1) + + cx1.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + + # force RT refresh, returning a different reader + del pool.routing_tables["db1"] + readers["db1"] = [str(READER2_ADDRESS)] + + cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + assert cx2.unresolved_address == READER2_ADDRESS + + cx1.close.assert_awaited_once() + assert len(pool.connections[READER1_ADDRESS]) == 0 + + await pool.release(cx2) + assert len(pool.connections[READER2_ADDRESS]) == 1 + + +@mark_async_test +async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( + custom_routing_opener +): + readers = { + "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], + "db2": [str(READER1_ADDRESS)] + } + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = AsyncNeo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + await pool.release(cx1) + assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS) + reader1_connection_count = len(pool.connections[READER1_ADDRESS]) + reader2_connection_count = len(pool.connections[READER2_ADDRESS]) + assert reader1_connection_count + reader2_connection_count == 1 + + cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + await pool.release(cx2) + assert cx2.unresolved_address == READER1_ADDRESS + cx1.close.assert_not_called() + cx2.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + + + # force RT refresh, returning a different reader + del pool.routing_tables["db2"] + readers["db2"] = [str(READER3_ADDRESS)] + + cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + await pool.release(cx3) + assert cx3.unresolved_address == READER3_ADDRESS + + cx1.close.assert_not_called() + cx2.close.assert_not_called() + cx3.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + assert len(pool.connections[READER3_ADDRESS]) == 1 diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index cfaaf1f34..3a5a2e79b 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -17,6 +17,7 @@ import inspect +from collections import defaultdict import pytest @@ -50,17 +51,23 @@ ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") -READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") -WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") +READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") +READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host") +READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host") +WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") @pytest.fixture -def routing_failure_opener(fake_connection_generator, mocker): - def make_opener(failures=None): +def custom_routing_opener(fake_connection_generator, mocker): + def make_opener(failures=None, get_readers=None): def routing_side_effect(*args, **kwargs): nonlocal failures res = next(failures, None) if res is None: + if get_readers is not None: + readers = get_readers(kwargs.get("database")) + else: + readers = [str(READER1_ADDRESS)] return [{ "ttl": 1000, "servers": [ @@ -68,8 +75,8 @@ def routing_side_effect(*args, **kwargs): str(ROUTER2_ADDRESS), str(ROUTER3_ADDRESS)], "role": "ROUTE"}, - {"addresses": [str(READER_ADDRESS)], "role": "READ"}, - {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"}, ], }] raise res @@ -96,8 +103,8 @@ def open_(addr, auth, timeout): @pytest.fixture -def opener(routing_failure_opener): - return routing_failure_opener() +def opener(custom_routing_opener): + return custom_routing_opener() def _pool_config(): @@ -177,9 +184,9 @@ def test_chooses_right_connection_type(opener, type_): ) pool.release(cx1) if type_ == "r": - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS else: - assert cx1.unresolved_address == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER1_ADDRESS @mark_sync_test @@ -298,9 +305,9 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): pool = _simple_pool(opener) - cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS cx1.reset.assert_not_called() @@ -311,11 +318,11 @@ def test_acquire_performs_liveness_check_on_existing_connection( ): pool = _simple_pool(opener) # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -326,7 +333,7 @@ def test_acquire_performs_liveness_check_on_existing_connection( cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx2 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) @@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = _simple_pool(opener) # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx2 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) assert cx1 is not cx2 assert cx1.unresolved_address == cx2.unresolved_address @@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = _simple_pool(opener) # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) - cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx2 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.unresolved_address == READER_ADDRESS - assert cx2.unresolved_address == READER_ADDRESS + assert cx1.unresolved_address == READER1_ADDRESS + assert cx2.unresolved_address == READER1_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs): cx2.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx3 = pool._acquire(READER_ADDRESS, None, Deadline(30), + cx3 = pool._acquire(READER1_ADDRESS, None, Deadline(30), liveness_timeout) assert cx3 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) @@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx): def close_side_effect(): cx.closed.return_value = True cx.defunct.return_value = True - pool.deactivate(READER_ADDRESS) + pool.deactivate(READER1_ADDRESS) cx.attach_mock(mocker.MagicMock(side_effect=close_side_effect), "close") @@ -470,9 +477,9 @@ def test__acquire_new_later_with_room(opener): pool = Neo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 1 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 1 assert callable(creator) if Util.is_async_code: assert inspect.iscoroutinefunction(creator) @@ -487,9 +494,9 @@ def test__acquire_new_later_without_room(opener): ) _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now - assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) - assert pool.connections_reservations[READER_ADDRESS] == 0 + assert pool.connections_reservations[READER1_ADDRESS] == 0 + creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) + assert pool.connections_reservations[READER1_ADDRESS] == 0 assert creator is None @@ -519,8 +526,8 @@ def test_passes_pool_config_to_connection(mocker): "Neo.ClientError.Security.AuthorizationExpired"), )) @mark_sync_test -def test_discovery_is_retried(routing_failure_opener, error): - opener = routing_failure_opener([ +def test_discovery_is_retried(custom_routing_opener, error): + opener = custom_routing_opener([ None, # first call to router for seeding the RT with more routers error, # will be retried ]) @@ -563,8 +570,8 @@ def test_discovery_is_retried(routing_failure_opener, error): ) )) @mark_sync_test -def test_fast_failing_discovery(routing_failure_opener, error): - opener = routing_failure_opener([ +def test_fast_failing_discovery(custom_routing_opener, error): + opener = custom_routing_opener([ None, # first call to router for seeding the RT with more routers error, # will be retried ]) @@ -648,3 +655,85 @@ def test_connection_error_callback( cx.mark_unauthenticated.assert_not_called() for cx in cxs_write: cx.mark_unauthenticated.assert_not_called() + + +@mark_sync_test +def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): + readers = {"db1": [str(READER1_ADDRESS)]} + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + assert cx1.unresolved_address == READER1_ADDRESS + pool.release(cx1) + + cx1.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + + # force RT refresh, returning a different reader + del pool.routing_tables["db1"] + readers["db1"] = [str(READER2_ADDRESS)] + + cx2 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + assert cx2.unresolved_address == READER2_ADDRESS + + cx1.close.assert_called_once() + assert len(pool.connections[READER1_ADDRESS]) == 0 + + pool.release(cx2) + assert len(pool.connections[READER2_ADDRESS]) == 1 + + +@mark_sync_test +def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( + custom_routing_opener +): + readers = { + "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], + "db2": [str(READER1_ADDRESS)] + } + + def get_readers(database): + return readers[database] + + opener = custom_routing_opener(get_readers=get_readers) + + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + pool.release(cx1) + assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS) + reader1_connection_count = len(pool.connections[READER1_ADDRESS]) + reader2_connection_count = len(pool.connections[READER2_ADDRESS]) + assert reader1_connection_count + reader2_connection_count == 1 + + cx2 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + pool.release(cx2) + assert cx2.unresolved_address == READER1_ADDRESS + cx1.close.assert_not_called() + cx2.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + + + # force RT refresh, returning a different reader + del pool.routing_tables["db2"] + readers["db2"] = [str(READER3_ADDRESS)] + + cx3 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + pool.release(cx3) + assert cx3.unresolved_address == READER3_ADDRESS + + cx1.close.assert_not_called() + cx2.close.assert_not_called() + cx3.close.assert_not_called() + assert len(pool.connections[READER1_ADDRESS]) == 1 + assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count + assert len(pool.connections[READER3_ADDRESS]) == 1