diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index d4ce3b17d8..b675b35768 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -410,6 +410,15 @@ Socket objects `Not implemented yet! `__ + We also keep track of an extra bit of state, because it turns out + to be useful for :class:`trio.SocketStream`: + + .. attribute:: did_shutdown_SHUT_WR + + This :class:`bool` attribute it True if you've called + ``sock.shutdown(SHUT_WR)`` or ``sock.shutdown(SHUT_RDWR)``, and + False otherwise. + The following methods are identical to their equivalents in :func:`socket.socket`, except async, and the ones that take address arguments require pre-resolved addresses: diff --git a/trio/_network.py b/trio/_network.py index ee5f80c0df..4e33b7ebc1 100644 --- a/trio/_network.py +++ b/trio/_network.py @@ -94,7 +94,7 @@ def __init__(self, sock): pass async def send_all(self, data): - if self.socket._did_SHUT_WR: + if self.socket.did_shutdown_SHUT_WR: await _core.yield_briefly() raise ClosedStreamError("can't send data after sending EOF") with self._send_lock.sync: @@ -112,7 +112,7 @@ async def send_eof(self): async with self._send_lock: # On MacOS, calling shutdown a second time raises ENOTCONN, but # send_eof needs to be idempotent. - if self.socket._did_SHUT_WR: + if self.socket.did_shutdown_SHUT_WR: return with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) diff --git a/trio/socket.py b/trio/socket.py index 58e7f0a889..ed75a58736 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -246,7 +246,7 @@ def __init__(self, sock): .format(type(sock).__name__)) self._sock = sock self._sock.setblocking(False) - self._did_SHUT_WR = False + self._did_shutdown_SHUT_WR = False # Defaults: if self._sock.family == AF_INET6: @@ -307,6 +307,10 @@ def type(self): def proto(self): return self._sock.proto + @property + def did_shutdown_SHUT_WR(self): + return self._did_shutdown_SHUT_WR + def __repr__(self): return repr(self._sock).replace("socket.socket", "trio.socket.socket") @@ -325,7 +329,7 @@ def shutdown(self, flag): self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [SHUT_WR, SHUT_RDWR]: - self._did_SHUT_WR = True + self._did_shutdown_SHUT_WR = True async def wait_writable(self): await _core.wait_socket_writable(self._sock) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index 9f91298208..3b67689093 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -323,11 +323,20 @@ async def test_SocketType_shutdown(): with a, b: await a.sendall(b"xxx") assert await b.recv(3) == b"xxx" + assert not a.did_shutdown_SHUT_WR + assert not b.did_shutdown_SHUT_WR a.shutdown(tsocket.SHUT_WR) + assert a.did_shutdown_SHUT_WR + assert not b.did_shutdown_SHUT_WR assert await b.recv(3) == b"" await b.sendall(b"yyy") assert await a.recv(3) == b"yyy" + b.shutdown(tsocket.SHUT_RD) + assert not b.did_shutdown_SHUT_WR + b.shutdown(tsocket.SHUT_RDWR) + assert b.did_shutdown_SHUT_WR + @pytest.mark.parametrize("address, socket_type", [('127.0.0.1', tsocket.AF_INET), ('::1', tsocket.AF_INET6)]) async def test_SocketType_simple_server(address, socket_type):