Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove quasi-public APIs from trio socket interface #249

Merged
merged 2 commits into from
Jul 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,15 @@ Socket objects

`Not implemented yet! <https://github.com/python-trio/trio/issues/45>`__

We also keep track of an extra bit of state, because it turns out
to be useful for :class:`trio.SocketStream`:

.. attribute:: did_shutdown_SHUT_WR

This :class:`bool` attribute it True if you've called
``sock.shutdown(SHUT_WR)`` or ``sock.shutdown(SHUT_RDWR)``, and
False otherwise.

The following methods are identical to their equivalents in
:func:`socket.socket`, except async, and the ones that take address
arguments require pre-resolved addresses:
Expand Down
6 changes: 3 additions & 3 deletions trio/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SocketStream(HalfCloseableStream):
def __init__(self, sock):
if not tsocket.is_trio_socket(sock):
raise TypeError("SocketStream requires trio socket object")
if sock._real_type != tsocket.SOCK_STREAM:
if tsocket._real_type(sock.type) != tsocket.SOCK_STREAM:
raise ValueError("SocketStream requires a SOCK_STREAM socket")
try:
sock.getpeername()
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(self, sock):
pass

async def send_all(self, data):
if self.socket._did_SHUT_WR:
if self.socket.did_shutdown_SHUT_WR:
await _core.yield_briefly()
raise ClosedStreamError("can't send data after sending EOF")
with self._send_lock.sync:
Expand All @@ -112,7 +112,7 @@ async def send_eof(self):
async with self._send_lock:
# On MacOS, calling shutdown a second time raises ENOTCONN, but
# send_eof needs to be idempotent.
if self.socket._did_SHUT_WR:
if self.socket.did_shutdown_SHUT_WR:
return
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
Expand Down
25 changes: 15 additions & 10 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ def is_trio_socket(obj):
getattr(_stdlib_socket, "SOCK_NONBLOCK", 0)
| getattr(_stdlib_socket, "SOCK_CLOEXEC", 0))

# Hopefully Python will eventually make something like this public
# (see bpo-21327) but I don't want to make it public myself and then
# find out they picked a different name... this is used internally in
# this file and also elsewhere in trio.
def _real_type(type_num):
return type_num & _SOCK_TYPE_MASK

class _SocketType:
def __init__(self, sock):
if type(sock) is not _stdlib_socket.socket:
Expand All @@ -239,13 +246,7 @@ def __init__(self, sock):
.format(type(sock).__name__))
self._sock = sock
self._sock.setblocking(False)
self._did_SHUT_WR = False

# Hopefully Python will eventually make something like this public
# (see bpo-21327) but I don't want to make it public myself and then
# find out they picked a different name... this is used internally in
# this file and also elsewhere in trio.
self._real_type = sock.type & _SOCK_TYPE_MASK
self._did_shutdown_SHUT_WR = False

# Defaults:
if self._sock.family == AF_INET6:
Expand Down Expand Up @@ -306,6 +307,10 @@ def type(self):
def proto(self):
return self._sock.proto

@property
def did_shutdown_SHUT_WR(self):
return self._did_shutdown_SHUT_WR

def __repr__(self):
return repr(self._sock).replace("socket.socket", "trio.socket.socket")

Expand All @@ -324,7 +329,7 @@ def shutdown(self, flag):
self._sock.shutdown(flag)
# only do this if the call succeeded:
if flag in [SHUT_WR, SHUT_RDWR]:
self._did_SHUT_WR = True
self._did_shutdown_SHUT_WR = True

async def wait_writable(self):
await _core.wait_socket_writable(self._sock)
Expand Down Expand Up @@ -362,7 +367,7 @@ def _check_address(self, address, *, require_resolved):
_stdlib_socket.getaddrinfo(
address[0], address[1],
self._sock.family,
self._real_type,
_real_type(self._sock.type),
self._sock.proto,
flags=_NUMERIC_ONLY)
except gaierror as exc:
Expand Down Expand Up @@ -399,7 +404,7 @@ async def _resolve_address(self, address, flags):
gai_res = await getaddrinfo(
address[0], address[1],
self._sock.family,
self._real_type,
_real_type(self._sock.type),
self._sock.proto,
flags)
# AFAICT from the spec it's not possible for getaddrinfo to return an
Expand Down
16 changes: 16 additions & 0 deletions trio/tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,27 @@ async def test_SocketType_shutdown():
with a, b:
await a.sendall(b"xxx")
assert await b.recv(3) == b"xxx"
assert not a.did_shutdown_SHUT_WR
assert not b.did_shutdown_SHUT_WR
a.shutdown(tsocket.SHUT_WR)
assert a.did_shutdown_SHUT_WR
assert not b.did_shutdown_SHUT_WR
assert await b.recv(3) == b""
await b.sendall(b"yyy")
assert await a.recv(3) == b"yyy"

a, b = tsocket.socketpair()
with a, b:
assert not a.did_shutdown_SHUT_WR
a.shutdown(tsocket.SHUT_RD)
assert not a.did_shutdown_SHUT_WR

a, b = tsocket.socketpair()
with a, b:
assert not a.did_shutdown_SHUT_WR
a.shutdown(tsocket.SHUT_RDWR)
assert a.did_shutdown_SHUT_WR


@pytest.mark.parametrize("address, socket_type", [('127.0.0.1', tsocket.AF_INET), ('::1', tsocket.AF_INET6)])
async def test_SocketType_simple_server(address, socket_type):
Expand Down