Skip to content

Commit

Permalink
Remove WSSConnector TLS presence check (#4695)
Browse files Browse the repository at this point in the history
Let it fail at connection time if TLS is not properly configured

Co-authored-by: Matthew Rocklin <mrocklin@gmail.com>
Co-authored-by: James Bourbeau <jrbourbeau@gmail.com>
  • Loading branch information
3 people authored Apr 13, 2021
1 parent 053f99b commit 6b69342
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
3 changes: 2 additions & 1 deletion distributed/comm/tests/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.core import FatalCommClosedError
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import ( # noqa: F401
Expand Down Expand Up @@ -71,7 +72,7 @@ async def test_expect_ssl_context(cleanup):
server_ctx = get_server_ssl_context()

async with listen("wss://", lambda comm: comm, ssl_context=server_ctx) as listener:
with pytest.raises(TypeError):
with pytest.raises(FatalCommClosedError, match="TLS expects a `ssl_context` *"):
comm = await connect(listener.contact_address)


Expand Down
7 changes: 5 additions & 2 deletions distributed/comm/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,10 @@ async def connect(self, address, deserialize=True, **connection_args):
except StreamClosedError as e:
convert_stream_closed_error(self, e)
except SSLError as err:
raise FatalCommClosedError() from err
raise FatalCommClosedError(
"TLS expects a `ssl_context` argument of type "
"ssl.SSLContext (perhaps check your TLS configuration?)"
) from err
return self.comm_class(sock, deserialize=deserialize)

def _get_connect_args(self, **connection_args):
Expand All @@ -388,7 +391,7 @@ class WSSConnector(WSConnector):
comm_class = WSS

def _get_connect_args(self, **connection_args):
ctx = _expect_tls_context(connection_args)
ctx = connection_args.get("ssl_context")
return {"ssl_options": ctx, **connection_args.get("extra_conn_args", {})}


Expand Down

0 comments on commit 6b69342

Please sign in to comment.