Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tightening the runtime type check for ssl (#7698) #8042

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7698.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for passing `True` to `ssl` while deprecating `None`. -- by :user:`xiangyan99`
22 changes: 17 additions & 5 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Set,
Expand Down Expand Up @@ -415,7 +414,7 @@ async def _request(
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
Expand All @@ -432,6 +431,11 @@ async def _request(
if self.closed:
raise RuntimeError("Session is closed")

if not isinstance(ssl, SSL_ALLOWED_TYPES):
raise TypeError(
"ssl should be SSLContext, Fingerprint, or bool, "
"got {!r} instead.".format(ssl)
)
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

if data is not None and json is not None:
Expand Down Expand Up @@ -571,7 +575,7 @@ async def _request(
proxy_auth=proxy_auth,
timer=timer,
session=self,
ssl=ssl,
ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr]
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
Expand Down Expand Up @@ -752,7 +756,7 @@ def ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
Expand Down Expand Up @@ -804,7 +808,7 @@ async def _ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
Expand Down Expand Up @@ -838,6 +842,14 @@ async def _ws_connect(
extstr = ws_ext_gen(compress=compress)
real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr

# For the sake of backward compatibility, if user passes in None, convert it to True
if ssl is None:
warnings.warn(
"ssl=None is deprecated, please use ssl=True",
DeprecationWarning,
stacklevel=2,
)
ssl = True
ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

# send request
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ def port(self) -> Optional[int]:
return self._conn_key.port

@property
def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]:
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl

def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)

# OSError.__reduce__ does too much black magick
Expand Down Expand Up @@ -221,7 +221,7 @@ def path(self) -> str:

def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)


Expand Down
19 changes: 9 additions & 10 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -151,11 +150,11 @@ def check(self, transport: asyncio.Transport) -> None:
if ssl is not None:
SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
else: # pragma: no cover
SSL_ALLOWED_TYPES = type(None)
SSL_ALLOWED_TYPES = (bool, type(None))


def _merge_ssl_params(
ssl: Union["SSLContext", Literal[False], Fingerprint, None],
ssl: Union["SSLContext", bool, Fingerprint, None],
verify_ssl: Optional[bool],
ssl_context: Optional["SSLContext"],
fingerprint: Optional[bytes],
Expand All @@ -166,7 +165,7 @@ def _merge_ssl_params(
DeprecationWarning,
stacklevel=3,
)
if ssl is not None:
if ssl is not True:
raise ValueError(
"verify_ssl, ssl_context, fingerprint and ssl "
"parameters are mutually exclusive"
Expand All @@ -179,7 +178,7 @@ def _merge_ssl_params(
DeprecationWarning,
stacklevel=3,
)
if ssl is not None:
if ssl is not True:
raise ValueError(
"verify_ssl, ssl_context, fingerprint and ssl "
"parameters are mutually exclusive"
Expand All @@ -192,7 +191,7 @@ def _merge_ssl_params(
DeprecationWarning,
stacklevel=3,
)
if ssl is not None:
if ssl is not True:
raise ValueError(
"verify_ssl, ssl_context, fingerprint and ssl "
"parameters are mutually exclusive"
Expand All @@ -214,7 +213,7 @@ class ConnectionKey:
host: str
port: Optional[int]
is_ssl: bool
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
ssl: Union[SSLContext, bool, Fingerprint]
proxy: Optional[URL]
proxy_auth: Optional[BasicAuth]
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
Expand Down Expand Up @@ -276,7 +275,7 @@ def __init__(
proxy_auth: Optional[BasicAuth] = None,
timer: Optional[BaseTimerContext] = None,
session: Optional["ClientSession"] = None,
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
Expand Down Expand Up @@ -315,7 +314,7 @@ def __init__(
real_response_class = response_class
self.response_class: Type[ClientResponse] = real_response_class
self._timer = timer if timer is not None else TimerNoop()
self._ssl = ssl
self._ssl = ssl if ssl is not None else True # type: ignore[redundant-expr]
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
self.server_hostname = server_hostname

if loop.get_debug():
Expand Down Expand Up @@ -357,7 +356,7 @@ def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

@property
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
return self._ssl

@property
Expand Down
11 changes: 8 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def __init__(
ttl_dns_cache: Optional[int] = 10,
family: int = 0,
ssl_context: Optional[SSLContext] = None,
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
ssl: Union[bool, Fingerprint, SSLContext] = True,
local_addr: Optional[Tuple[str, int]] = None,
resolver: Optional[AbstractResolver] = None,
keepalive_timeout: Union[None, float, object] = sentinel,
Expand All @@ -791,6 +791,11 @@ def __init__(
timeout_ceil_threshold=timeout_ceil_threshold,
)

if not isinstance(ssl, SSL_ALLOWED_TYPES):
raise TypeError(
"ssl should be SSLContext, Fingerprint, or bool, "
"got {!r} instead.".format(ssl)
)
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
if resolver is None:
resolver = DefaultResolver(loop=self._loop)
Expand Down Expand Up @@ -965,13 +970,13 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not None:
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not None:
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
return self._make_ssl_context(True)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class TestClientConnectorError:
host="example.com",
port=8080,
is_ssl=False,
ssl=None,
ssl=True,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None,
Expand All @@ -136,7 +136,7 @@ def test_ctor(self) -> None:
assert err.os_error.strerror == "No such file"
assert err.host == "example.com"
assert err.port == 8080
assert err.ssl is None
assert err.ssl is True

def test_pickle(self) -> None:
err = client.ClientConnectorError(
Expand All @@ -153,7 +153,7 @@ def test_pickle(self) -> None:
assert err2.os_error.strerror == "No such file"
assert err2.host == "example.com"
assert err2.port == 8080
assert err2.ssl is None
assert err2.ssl is True
assert err2.foo == "bar"

def test_repr(self) -> None:
Expand All @@ -171,7 +171,7 @@ def test_str(self) -> None:
os_error=OSError(errno.ENOENT, "No such file"),
)
assert str(err) == (
"Cannot connect to host example.com:8080 ssl:" "default [No such file]"
"Cannot connect to host example.com:8080 ssl:default [No such file]"
)


Expand All @@ -180,7 +180,7 @@ class TestClientConnectorCertificateError:
host="example.com",
port=8080,
is_ssl=False,
ssl=None,
ssl=True,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_host_port_default_http(make_request) -> None:
req = make_request("get", "http://python.org/")
assert req.host == "python.org"
assert req.port == 80
assert not req.ssl
assert not req.is_ssl()


def test_host_port_default_https(make_request) -> None:
Expand Down Expand Up @@ -400,7 +400,7 @@ def test_ipv6_default_http_port(make_request) -> None:
req = make_request("get", "http://[2001:db8::1]/")
assert req.host == "2001:db8::1"
assert req.port == 80
assert not req.ssl
assert not req.is_ssl()


def test_ipv6_default_https_port(make_request) -> None:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@
@pytest.fixture()
def key():
# Connection key
return ConnectionKey("localhost", 80, False, None, None, None, None)
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def key2():
# Connection key
return ConnectionKey("localhost", 80, False, None, None, None, None)
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def ssl_key():
# Connection key
return ConnectionKey("localhost", 80, True, None, None, None, None)
return ConnectionKey("localhost", 80, True, True, None, None, None)


@pytest.fixture
Expand Down Expand Up @@ -1467,9 +1467,9 @@ async def test_cleanup_closed_disabled(loop, mocker) -> None:
assert not conn._cleanup_closed_transports


async def test_tcp_connector_ctor(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
assert conn._ssl is None
async def test_tcp_connector_ctor() -> None:
conn = aiohttp.TCPConnector()
assert conn._ssl is True

assert conn.use_dns_cache
assert conn.family == 0
Expand Down Expand Up @@ -1555,7 +1555,7 @@ async def test___get_ssl_context3(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop, ssl=ctx)
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = None
req.ssl = True
assert conn._get_ssl_context(req) is ctx


Expand All @@ -1581,7 +1581,7 @@ async def test___get_ssl_context6(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = None
req.ssl = True
assert conn._get_ssl_context(req) is conn._make_ssl_context(True)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def make_conn():
auth=None,
headers={"Host": "www.python.org"},
loop=self.loop,
ssl=None,
ssl=True,
)

conn.close()
Expand Down Expand Up @@ -150,7 +150,7 @@ async def make_conn():
auth=None,
headers={"Host": "www.python.org", "Foo": "Bar"},
loop=self.loop,
ssl=None,
ssl=True,
)

conn.close()
Expand Down
Loading