diff --git a/CHANGES/7698.feature b/CHANGES/7698.feature new file mode 100644 index 00000000000..e8c4b3fb452 --- /dev/null +++ b/CHANGES/7698.feature @@ -0,0 +1 @@ +Added support for passing `True` to `ssl` while deprecating `None`. -- by :user:`xiangyan99` diff --git a/aiohttp/client.py b/aiohttp/client.py index d08211bd00e..36dbf6a7119 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -22,7 +22,6 @@ Generic, Iterable, List, - Literal, Mapping, Optional, Set, @@ -415,7 +414,7 @@ async def _request( verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, - ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, trace_request_ctx: Optional[SimpleNamespace] = None, @@ -571,7 +570,7 @@ async def _request( proxy_auth=proxy_auth, timer=timer, session=self, - ssl=ssl, + ssl=ssl if ssl is not None else True, server_hostname=server_hostname, proxy_headers=proxy_headers, traces=traces, @@ -752,7 +751,7 @@ def ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -804,7 +803,7 @@ async def _ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -838,6 +837,14 @@ async def _ws_connect( extstr = ws_ext_gen(compress=compress) real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr + # For the sake of backward compatibility, if user passes in None, convert it to True + if ssl is None: + warnings.warn( + "ssl=None is deprecated, please use ssl=True", + DeprecationWarning, + stacklevel=2, + ) + ssl = True ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) # send request diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index d70988f6ede..60bf058e887 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -182,12 +182,12 @@ def port(self) -> Optional[int]: return self._conn_key.port @property - def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]: + def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]: return self._conn_key.ssl def __str__(self) -> str: return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format( - self, self.ssl if self.ssl is not None else "default", self.strerror + self, "default" if self.ssl is True else self.ssl, self.strerror ) # OSError.__reduce__ does too much black magick @@ -221,7 +221,7 @@ def path(self) -> str: def __str__(self) -> str: return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format( - self, self.ssl if self.ssl is not None else "default", self.strerror + self, "default" if self.ssl is True else self.ssl, self.strerror ) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 4ae0ecbcdfb..bb43ae9318d 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -17,7 +17,6 @@ Dict, Iterable, List, - Literal, Mapping, Optional, Tuple, @@ -151,22 +150,22 @@ def check(self, transport: asyncio.Transport) -> None: if ssl is not None: SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None)) else: # pragma: no cover - SSL_ALLOWED_TYPES = type(None) + SSL_ALLOWED_TYPES = (bool, type(None)) def _merge_ssl_params( - ssl: Union["SSLContext", Literal[False], Fingerprint, None], + ssl: Union["SSLContext", bool, Fingerprint], verify_ssl: Optional[bool], ssl_context: Optional["SSLContext"], fingerprint: Optional[bytes], -) -> Union["SSLContext", Literal[False], Fingerprint, None]: +) -> Union["SSLContext", bool, Fingerprint]: if verify_ssl is not None and not verify_ssl: warnings.warn( "verify_ssl is deprecated, use ssl=False instead", DeprecationWarning, stacklevel=3, ) - if ssl is not None: + if ssl is not True: raise ValueError( "verify_ssl, ssl_context, fingerprint and ssl " "parameters are mutually exclusive" @@ -179,7 +178,7 @@ def _merge_ssl_params( DeprecationWarning, stacklevel=3, ) - if ssl is not None: + if ssl is not True: raise ValueError( "verify_ssl, ssl_context, fingerprint and ssl " "parameters are mutually exclusive" @@ -192,7 +191,7 @@ def _merge_ssl_params( DeprecationWarning, stacklevel=3, ) - if ssl is not None: + if ssl is not True: raise ValueError( "verify_ssl, ssl_context, fingerprint and ssl " "parameters are mutually exclusive" @@ -214,7 +213,7 @@ class ConnectionKey: host: str port: Optional[int] is_ssl: bool - ssl: Union[SSLContext, None, Literal[False], Fingerprint] + ssl: Union[SSLContext, bool, Fingerprint] proxy: Optional[URL] proxy_auth: Optional[BasicAuth] proxy_headers_hash: Optional[int] # hash(CIMultiDict) @@ -276,7 +275,7 @@ def __init__( proxy_auth: Optional[BasicAuth] = None, timer: Optional[BaseTimerContext] = None, session: Optional["ClientSession"] = None, - ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, proxy_headers: Optional[LooseHeaders] = None, traces: Optional[List["Trace"]] = None, trust_env: bool = False, @@ -315,7 +314,7 @@ def __init__( real_response_class = response_class self.response_class: Type[ClientResponse] = real_response_class self._timer = timer if timer is not None else TimerNoop() - self._ssl = ssl + self._ssl = ssl if ssl is not None else True self.server_hostname = server_hostname if loop.get_debug(): @@ -357,7 +356,7 @@ def is_ssl(self) -> bool: return self.url.scheme in ("https", "wss") @property - def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]: + def ssl(self) -> Union["SSLContext", bool, Fingerprint]: return self._ssl @property diff --git a/aiohttp/connector.py b/aiohttp/connector.py index baa3a7170f6..d0954355244 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -768,7 +768,7 @@ def __init__( ttl_dns_cache: Optional[int] = 10, family: int = 0, ssl_context: Optional[SSLContext] = None, - ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None, + ssl: Union[bool, Fingerprint, SSLContext] = True, local_addr: Optional[Tuple[str, int]] = None, resolver: Optional[AbstractResolver] = None, keepalive_timeout: Union[None, float, object] = sentinel, @@ -965,13 +965,13 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: sslcontext = req.ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext - if sslcontext is not None: + if sslcontext is not True: # not verified or fingerprinted return self._make_ssl_context(False) sslcontext = self._ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext - if sslcontext is not None: + if sslcontext is not True: # not verified or fingerprinted return self._make_ssl_context(False) return self._make_ssl_context(True) diff --git a/tests/test_client_exceptions.py b/tests/test_client_exceptions.py index 8f34e4cc73c..f70ba5d09a6 100644 --- a/tests/test_client_exceptions.py +++ b/tests/test_client_exceptions.py @@ -119,7 +119,7 @@ class TestClientConnectorError: host="example.com", port=8080, is_ssl=False, - ssl=None, + ssl=True, proxy=None, proxy_auth=None, proxy_headers_hash=None, @@ -136,7 +136,7 @@ def test_ctor(self) -> None: assert err.os_error.strerror == "No such file" assert err.host == "example.com" assert err.port == 8080 - assert err.ssl is None + assert err.ssl is True def test_pickle(self) -> None: err = client.ClientConnectorError( @@ -153,7 +153,7 @@ def test_pickle(self) -> None: assert err2.os_error.strerror == "No such file" assert err2.host == "example.com" assert err2.port == 8080 - assert err2.ssl is None + assert err2.ssl is True assert err2.foo == "bar" def test_repr(self) -> None: @@ -171,7 +171,7 @@ def test_str(self) -> None: os_error=OSError(errno.ENOENT, "No such file"), ) assert str(err) == ( - "Cannot connect to host example.com:8080 ssl:" "default [No such file]" + "Cannot connect to host example.com:8080 ssl:default [No such file]" ) @@ -180,7 +180,7 @@ class TestClientConnectorCertificateError: host="example.com", port=8080, is_ssl=False, - ssl=None, + ssl=True, proxy=None, proxy_auth=None, proxy_headers_hash=None, diff --git a/tests/test_client_fingerprint.py b/tests/test_client_fingerprint.py index b1ae3cae36e..68dd528e0a2 100644 --- a/tests/test_client_fingerprint.py +++ b/tests/test_client_fingerprint.py @@ -37,7 +37,7 @@ def test_fingerprint_check_no_ssl() -> None: def test__merge_ssl_params_verify_ssl() -> None: with pytest.warns(DeprecationWarning): - assert _merge_ssl_params(None, False, None, None) is False + assert _merge_ssl_params(True, False, None, None) is False def test__merge_ssl_params_verify_ssl_conflict() -> None: @@ -50,7 +50,7 @@ def test__merge_ssl_params_verify_ssl_conflict() -> None: def test__merge_ssl_params_ssl_context() -> None: ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) with pytest.warns(DeprecationWarning): - assert _merge_ssl_params(None, None, ctx, None) is ctx + assert _merge_ssl_params(True, None, ctx, None) is ctx def test__merge_ssl_params_ssl_context_conflict() -> None: @@ -64,7 +64,7 @@ def test__merge_ssl_params_ssl_context_conflict() -> None: def test__merge_ssl_params_fingerprint() -> None: digest = hashlib.sha256(b"123").digest() with pytest.warns(DeprecationWarning): - ret = _merge_ssl_params(None, None, None, digest) + ret = _merge_ssl_params(True, None, None, digest) assert ret.fingerprint == digest diff --git a/tests/test_client_request.py b/tests/test_client_request.py index c8ce98d4034..6521b70ad55 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -166,7 +166,7 @@ def test_host_port_default_http(make_request) -> None: req = make_request("get", "http://python.org/") assert req.host == "python.org" assert req.port == 80 - assert not req.ssl + assert not req.is_ssl() def test_host_port_default_https(make_request) -> None: @@ -400,7 +400,7 @@ def test_ipv6_default_http_port(make_request) -> None: req = make_request("get", "http://[2001:db8::1]/") assert req.host == "2001:db8::1" assert req.port == 80 - assert not req.ssl + assert not req.is_ssl() def test_ipv6_default_https_port(make_request) -> None: diff --git a/tests/test_connector.py b/tests/test_connector.py index 1faec002487..84c03fc6fb5 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -30,19 +30,19 @@ @pytest.fixture() def key(): # Connection key - return ConnectionKey("localhost", 80, False, None, None, None, None) + return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture def key2(): # Connection key - return ConnectionKey("localhost", 80, False, None, None, None, None) + return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture def ssl_key(): # Connection key - return ConnectionKey("localhost", 80, True, None, None, None, None) + return ConnectionKey("localhost", 80, True, True, None, None, None) @pytest.fixture @@ -1467,9 +1467,9 @@ async def test_cleanup_closed_disabled(loop, mocker) -> None: assert not conn._cleanup_closed_transports -async def test_tcp_connector_ctor(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) - assert conn._ssl is None +async def test_tcp_connector_ctor() -> None: + conn = aiohttp.TCPConnector() + assert conn._ssl is True assert conn.use_dns_cache assert conn.family == 0 @@ -1555,7 +1555,7 @@ async def test___get_ssl_context3(loop) -> None: conn = aiohttp.TCPConnector(loop=loop, ssl=ctx) req = mock.Mock() req.is_ssl.return_value = True - req.ssl = None + req.ssl = True assert conn._get_ssl_context(req) is ctx @@ -1581,7 +1581,7 @@ async def test___get_ssl_context6(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) req = mock.Mock() req.is_ssl.return_value = True - req.ssl = None + req.ssl = True assert conn._get_ssl_context(req) is conn._make_ssl_context(True) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 2a8643f5047..f335e42c254 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -92,7 +92,7 @@ async def make_conn(): auth=None, headers={"Host": "www.python.org"}, loop=self.loop, - ssl=None, + ssl=True, ) conn.close() @@ -150,7 +150,7 @@ async def make_conn(): auth=None, headers={"Host": "www.python.org", "Foo": "Bar"}, loop=self.loop, - ssl=None, + ssl=True, ) conn.close()