Skip to content

Commit

Permalink
[PR #8672/c3219bf backport][3.10] Fix TCPConnector doing blocking I/O…
Browse files Browse the repository at this point in the history
… in the event loop to create the SSLContext (#8673)

Co-authored-by: Sam Bull <git@sambull.org>
Co-authored-by: pre-commit-ci[bot]
  • Loading branch information
bdraco and Dreamsorcerer authored Aug 10, 2024
1 parent f96182a commit f3fcba4
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 59 deletions.
3 changes: 3 additions & 0 deletions CHANGES/8672.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed :py:class:`aiohttp.TCPConnector` doing blocking I/O in the event loop to create the ``SSLContext`` -- by :user:`bdraco`.

The blocking I/O would only happen once per verify mode. However, it could cause the event loop to block for a long time if the ``SSLContext`` creation is slow, which is more likely during startup when the disk cache is not yet present.
104 changes: 64 additions & 40 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@
)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
from .helpers import ceil_timeout, is_ip_address, noop, sentinel
from .helpers import (
ceil_timeout,
is_ip_address,
noop,
sentinel,
set_exception,
set_result,
)
from .locks import EventResultOrError
from .resolver import DefaultResolver

Expand Down Expand Up @@ -771,6 +778,7 @@ class TCPConnector(BaseConnector):
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
_made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {}

def __init__(
self,
Expand Down Expand Up @@ -969,29 +977,24 @@ async def _create_connection(
return proto

@staticmethod
@functools.lru_cache(None)
def _make_ssl_context(verified: bool) -> SSLContext:
"""Create SSL context.
This method is not async-friendly and should be called from a thread
because it will load certificates from disk and do other blocking I/O.
"""
if verified:
return ssl.create_default_context()
else:
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
try:
sslcontext.options |= ssl.OP_NO_COMPRESSION
except AttributeError as attr_err:
warnings.warn(
"{!s}: The Python interpreter is compiled "
"against OpenSSL < 1.0.0. Ref: "
"https://docs.python.org/3/library/ssl.html"
"#ssl.OP_NO_COMPRESSION".format(attr_err),
)
sslcontext.set_default_verify_paths()
return sslcontext

def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
sslcontext.options |= ssl.OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
return sslcontext

async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context
0. if req.ssl is false, return None
Expand All @@ -1005,25 +1008,46 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
3. if verify_ssl is False in req, generate a SSL context that
won't verify
"""
if req.is_ssl():
if ssl is None: # pragma: no cover
raise RuntimeError("SSL is not supported.")
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
return self._make_ssl_context(True)
else:
if not req.is_ssl():
return None

if ssl is None: # pragma: no cover
raise RuntimeError("SSL is not supported.")
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
return await self._make_or_get_ssl_context(True)

async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext:
"""Create or get cached SSL context."""
try:
return await self._made_ssl_context[verified]
except KeyError:
loop = self._loop
future = loop.create_future()
self._made_ssl_context[verified] = future
try:
result = await loop.run_in_executor(
None, self._make_ssl_context, verified
)
# BaseException is used since we might get CancelledError
except BaseException as ex:
del self._made_ssl_context[verified]
set_exception(future, ex)
raise
else:
set_result(future, result)
return result

def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
if isinstance(ret, Fingerprint):
Expand Down Expand Up @@ -1180,7 +1204,7 @@ async def _start_tls_connection(
# `req.is_ssl()` evaluates to `False` which is never gonna happen
# in this code path. Of course, it's rather fragile
# maintainability-wise but this is to be solved separately.
sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req))
sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req))

try:
async with ceil_timeout(
Expand Down Expand Up @@ -1258,7 +1282,7 @@ async def _create_direct_connection(
*,
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.Transport, ResponseHandler]:
sslcontext = self._get_ssl_context(req)
sslcontext = await self._get_ssl_context(req)
fingerprint = self._get_fingerprint(req)

host = req.url.raw_host
Expand Down
78 changes: 60 additions & 18 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,23 +1540,23 @@ async def test_tcp_connector_clear_dns_cache_bad_args(loop) -> None:
conn.clear_dns_cache("localhost")


async def test_dont_recreate_ssl_context(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
ctx = conn._make_ssl_context(True)
assert ctx is conn._make_ssl_context(True)
async def test_dont_recreate_ssl_context() -> None:
conn = aiohttp.TCPConnector()
ctx = await conn._make_or_get_ssl_context(True)
assert ctx is await conn._make_or_get_ssl_context(True)


async def test_dont_recreate_ssl_context2(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
ctx = conn._make_ssl_context(False)
assert ctx is conn._make_ssl_context(False)
async def test_dont_recreate_ssl_context2() -> None:
conn = aiohttp.TCPConnector()
ctx = await conn._make_or_get_ssl_context(False)
assert ctx is await conn._make_or_get_ssl_context(False)


async def test___get_ssl_context1(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
async def test___get_ssl_context1() -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = False
assert conn._get_ssl_context(req) is None
assert await conn._get_ssl_context(req) is None


async def test___get_ssl_context2(loop) -> None:
Expand All @@ -1565,7 +1565,7 @@ async def test___get_ssl_context2(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = ctx
assert conn._get_ssl_context(req) is ctx
assert await conn._get_ssl_context(req) is ctx


async def test___get_ssl_context3(loop) -> None:
Expand All @@ -1574,7 +1574,7 @@ async def test___get_ssl_context3(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert conn._get_ssl_context(req) is ctx
assert await conn._get_ssl_context(req) is ctx


async def test___get_ssl_context4(loop) -> None:
Expand All @@ -1583,7 +1583,9 @@ async def test___get_ssl_context4(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = False
assert conn._get_ssl_context(req) is conn._make_ssl_context(False)
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)


async def test___get_ssl_context5(loop) -> None:
Expand All @@ -1592,15 +1594,55 @@ async def test___get_ssl_context5(loop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest())
assert conn._get_ssl_context(req) is conn._make_ssl_context(False)
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)


async def test___get_ssl_context6(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
async def test___get_ssl_context6() -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True)


async def test_ssl_context_once() -> None:
"""Test the ssl context is created only once and shared between connectors."""
conn1 = aiohttp.TCPConnector()
conn2 = aiohttp.TCPConnector()
conn3 = aiohttp.TCPConnector()

req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert conn._get_ssl_context(req) is conn._make_ssl_context(True)
assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context
assert True in conn1._made_ssl_context


@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError])
async def test_ssl_context_creation_raises(exception: BaseException) -> None:
"""Test that we try again if SSLContext creation fails the first time."""
conn = aiohttp.TCPConnector()
conn._made_ssl_context.clear()

with mock.patch.object(
conn, "_make_ssl_context", side_effect=exception
), pytest.raises( # type: ignore[call-overload]
exception
):
await conn._make_or_get_ssl_context(True)

assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext)


async def test_close_twice(loop) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ async def make_conn():
self.loop.start_tls.assert_called_with(
mock.ANY,
mock.ANY,
connector._make_ssl_context(True),
self.loop.run_until_complete(connector._make_or_get_ssl_context(True)),
server_hostname="www.python.org",
ssl_handshake_timeout=mock.ANY,
)
Expand Down

0 comments on commit f3fcba4

Please sign in to comment.