Skip to content

Commit

Permalink
Fix closure of sockets after failed connection (#795)
Browse files Browse the repository at this point in the history
* Fix closure of sockets after failed connection
* Fix missing await and logging after socket broke
  • Loading branch information
robsdedude authored Sep 6, 2022
1 parent 8cbcd83 commit c468106
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
4 changes: 2 additions & 2 deletions neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def time_remaining():
bolt_cls = AsyncBolt5x0
else:
log.debug("[#%04X] S: <CLOSE>", s.getsockname()[1])
AsyncBoltSocket.close_socket(s)
await AsyncBoltSocket.close_socket(s)

supported_versions = cls.protocol_handlers().keys()
raise BoltHandshakeError(
Expand All @@ -374,7 +374,7 @@ def time_remaining():
finally:
connection.socket.set_deadline(None)
except Exception as e:
log.debug("[#%04X] C: <OPEN FAILED> %r", s.getsockname()[1], e)
log.debug("[#%04X] C: <OPEN FAILED> %r", connection.local_port, e)
connection.kill()
raise

Expand Down
49 changes: 30 additions & 19 deletions neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl):
raise
except (SSLError, CertificateError) as error:
local_port = s.getsockname()[1]
if s:
await cls.close_socket(s)
raise BoltSecurityError(
message="Failed to establish encrypted connection.",
address=(resolved_address.host_name, local_port)
Expand All @@ -261,7 +263,8 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl):
log.debug("[#0000] C: <ERROR> %s %s", type(error).__name__,
" ".join(map(repr, error.args)))
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
s.close()
if s:
await cls.close_socket(s)
raise ServiceUnavailable(
"Failed to establish connection to {!r} (reason {})".format(
resolved_address, error))
Expand Down Expand Up @@ -334,14 +337,20 @@ async def _handshake(self, resolved_address):

@classmethod
async def close_socket(cls, socket_):
try:
if isinstance(socket_, AsyncBoltSocket):
if isinstance(socket_, AsyncBoltSocket):
try:
await socket_.close()
else:
except OSError:
pass
else:
try:
socket_.shutdown(SHUT_RDWR)
except OSError:
pass
try:
socket_.close()
except OSError:
pass
except OSError:
pass

@classmethod
async def connect(cls, address, *, timeout, custom_resolver, ssl_context,
Expand Down Expand Up @@ -463,8 +472,7 @@ def sendall(self, data):
return self._wait_for_io(self._socket.sendall, data)

def close(self):
self._socket.shutdown(SHUT_RDWR)
self._socket.close()
self.close_socket(self._socket)

def kill(self):
self._socket.close()
Expand Down Expand Up @@ -509,7 +517,7 @@ def _connect(cls, resolved_address, timeout, keep_alive):
log.debug("[#0000] C: <ERROR> %s %s", type(error).__name__,
" ".join(map(repr, error.args)))
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
s.close()
cls.close_socket(s)
raise ServiceUnavailable(
"Failed to establish connection to {!r} (reason {})".format(
resolved_address, error))
Expand All @@ -524,6 +532,7 @@ def _secure(cls, s, host, ssl_context):
sni_host = host if HAS_SNI and host else None
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
except (OSError, SSLError, CertificateError) as cause:
cls.close_socket(s)
raise BoltSecurityError(
message="Failed to establish encrypted connection.",
address=(host, local_port)
Expand Down Expand Up @@ -582,20 +591,20 @@ def _handshake(cls, s, resolved_address):
# If no data is returned after a successful select
# response, the server has closed the connection
log.debug("[#%04X] S: <CLOSE>", local_port)
BoltSocket.close_socket(s)
cls.close_socket(s)
raise ServiceUnavailable(
"Connection to {address} closed without handshake response".format(
address=resolved_address))
if data_size != 4:
# Some garbled data has been received
log.debug("[#%04X] S: @*#!", local_port)
s.close()
cls.close_socket(s)
raise BoltProtocolError(
"Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
resolved_address, data), address=resolved_address)
elif data == b"HTTP":
log.debug("[#%04X] S: <CLOSE>", local_port)
BoltSocket.close_socket(s)
cls.close_socket(s)
raise ServiceUnavailable(
"Cannot to connect to Bolt service on {!r} "
"(looks like HTTP)".format(resolved_address))
Expand All @@ -606,12 +615,14 @@ def _handshake(cls, s, resolved_address):

@classmethod
def close_socket(cls, socket_):
if isinstance(socket_, BoltSocket):
socket_ = socket_._socket
try:
if isinstance(socket_, BoltSocket):
socket_.close()
else:
socket_.shutdown(SHUT_RDWR)
socket_.close()
socket_.shutdown(SHUT_RDWR)
except OSError:
pass
try:
socket_.close()
except OSError:
pass

Expand Down Expand Up @@ -647,11 +658,11 @@ def connect(cls, address, *, timeout, custom_resolver, ssl_context,
log.debug("[#%04X] C: <CONNECTION FAILED> %s", local_port,
err_str)
if s:
BoltSocket.close_socket(s)
cls.close_socket(s)
errors.append(error)
except Exception:
if s:
BoltSocket.close_socket(s)
cls.close_socket(s)
raise
if not errors:
raise ServiceUnavailable(
Expand Down
2 changes: 1 addition & 1 deletion neo4j/_sync/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def time_remaining():
finally:
connection.socket.set_deadline(None)
except Exception as e:
log.debug("[#%04X] C: <OPEN FAILED> %r", s.getsockname()[1], e)
log.debug("[#%04X] C: <OPEN FAILED> %r", connection.local_port, e)
connection.kill()
raise

Expand Down

0 comments on commit c468106

Please sign in to comment.