Skip to content

Commit

Permalink
add a lot more coverage for failure cases
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Oct 10, 2024
1 parent ec95090 commit f530bb9
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2108,6 +2108,143 @@ async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]:
assert len(conn._resolve_host_tasks) == 0


async def test_multiple_dns_resolution_requests_first_cancelled(
loop: asyncio.AbstractEventLoop,
) -> None:
"""Verify that first DNS resolution cancellation does not make other resolutions fail."""

async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]:
"""Delayed resolve() task."""
for _ in range(3):
await asyncio.sleep(0)
return [
{
"hostname": "localhost",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
},
]

conn = aiohttp.TCPConnector(force_close=True)
req = ClientRequest(
"GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock()
)
with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch(
"aiohttp.connector.aiohappyeyeballs.start_connection",
side_effect=OSError(1, "Forced connection to fail"),
):
task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

# Let it create the internal task
await asyncio.sleep(0)
# Let that task start running
await asyncio.sleep(0)

# Ensure the task is running
assert len(conn._resolve_host_tasks) == 1

task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))
task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

task1.cancel()
with pytest.raises(asyncio.CancelledError):
await task1

# The second and third tasks should still make the connection
# even if the first one is cancelled
with pytest.raises(
aiohttp.ClientConnectorError, match="Forced connection to fail"
):
await task2
with pytest.raises(
aiohttp.ClientConnectorError, match="Forced connection to fail"
):
await task3

# Verify the the task is finished
assert len(conn._resolve_host_tasks) == 0


async def test_multiple_dns_resolution_requests_first_fails_second_successful(
loop: asyncio.AbstractEventLoop,
) -> None:
"""Verify that first DNS resolution fails the first time and is successful the second time."""
attempt = 0

async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]:
"""Delayed resolve() task."""
nonlocal attempt
for _ in range(3):
await asyncio.sleep(0)
attempt += 1
if attempt == 1:
raise OSError(None, "DNS Resolution mock failure")
return [
{
"hostname": "localhost",
"host": "127.0.0.1",
"port": 80,
"family": socket.AF_INET,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
},
]

conn = aiohttp.TCPConnector(force_close=True)
req = ClientRequest(
"GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock()
)
with mock.patch.object(conn._resolver, "resolve", delay_resolve), mock.patch(
"aiohttp.connector.aiohappyeyeballs.start_connection",
side_effect=OSError(1, "Forced connection to fail"),
):
task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

# Let it create the internal task
await asyncio.sleep(0)
# Let that task start running
await asyncio.sleep(0)

# Ensure the task is running
assert len(conn._resolve_host_tasks) == 1

task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

with pytest.raises(
aiohttp.ClientConnectorError, match="DNS Resolution mock failure"
):
await task1

assert len(conn._resolve_host_tasks) == 0
# The second task should also get the dns resolution failure
with pytest.raises(
aiohttp.ClientConnectorError, match="DNS Resolution mock failure"
):
await task2

# The third task is created after the resolution finished so
# it should try again and succeed
task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))
# Let it create the internal task
await asyncio.sleep(0)
# Let that task start running
await asyncio.sleep(0)

# Ensure the task is running
assert len(conn._resolve_host_tasks) == 1

with pytest.raises(
aiohttp.ClientConnectorError, match="Forced connection to fail"
):
await task3

# Verify the the task is finished
assert len(conn._resolve_host_tasks) == 0


async def test_close_abort_closed_transports(loop: asyncio.AbstractEventLoop) -> None:
tr = mock.Mock()

Expand Down

0 comments on commit f530bb9

Please sign in to comment.