Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor UnLock -> ConflictDetector #295

Merged
merged 1 commit into from
Aug 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import _core
from . import socket as tsocket
from ._socket import real_socket_type
from ._util import UnLock
from ._util import ConflictDetector
from .abc import HalfCloseableStream, Listener
from ._highlevel_generic import (
ClosedStreamError, BrokenStreamError, ClosedListenerError
Expand Down Expand Up @@ -71,8 +71,7 @@ def __init__(self, socket):
raise err from None

self.socket = socket
self._send_lock = UnLock(
_core.ResourceBusyError,
self._send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SocketStream"
)

Expand Down Expand Up @@ -105,19 +104,19 @@ async def send_all(self, data):
if self.socket.did_shutdown_SHUT_WR:
await _core.yield_briefly()
raise ClosedStreamError("can't send data after sending EOF")
with self._send_lock.sync:
with self._send_conflict_detector.sync:
with _translate_socket_errors_to_stream_errors():
await self.socket.sendall(data)

async def wait_send_all_might_not_block(self):
async with self._send_lock:
async with self._send_conflict_detector:
if self.socket.fileno() == -1:
raise ClosedStreamError
with _translate_socket_errors_to_stream_errors():
await self.socket.wait_writable()

async def send_eof(self):
async with self._send_lock:
async with self._send_conflict_detector:
# On MacOS, calling shutdown a second time raises ENOTCONN, but
# send_eof needs to be idempotent.
if self.socket.did_shutdown_SHUT_WR:
Expand Down
17 changes: 8 additions & 9 deletions trio/_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
BrokenStreamError, ClosedStreamError, aclose_forcefully
)
from . import _sync
from ._util import UnLock
from ._util import ConflictDetector

__all__ = ["SSLStream", "SSLListener"]

Expand Down Expand Up @@ -368,12 +368,10 @@ def __init__(
# These are used to make sure that our caller doesn't attempt to make
# multiple concurrent calls to send_all/wait_send_all_might_not_block
# or to receive_some.
self._outer_send_lock = UnLock(
_core.ResourceBusyError,
self._outer_send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SSLStream"
)
self._outer_recv_lock = UnLock(
_core.ResourceBusyError,
self._outer_recv_conflict_detector = ConflictDetector(
"another task is currently receiving data on this SSLStream"
)

Expand Down Expand Up @@ -624,7 +622,7 @@ async def receive_some(self, max_bytes):
:exc:`trio.BrokenStreamError`.

"""
async with self._outer_recv_lock:
async with self._outer_recv_conflict_detector:
self._check_status()
try:
await self._handshook.ensure(checkpoint=False)
Expand Down Expand Up @@ -666,7 +664,7 @@ async def send_all(self, data):
:exc:`trio.BrokenStreamError`.

"""
async with self._outer_send_lock:
async with self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
# SSLObject interprets write(b"") as an EOF for some reason, which
Expand All @@ -693,7 +691,8 @@ async def unwrap(self):
``transport_stream.receive_some(...)``.

"""
async with self._outer_recv_lock, self._outer_send_lock:
async with self._outer_recv_conflict_detector, \
self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
await self._retry(self._ssl_object.unwrap)
Expand Down Expand Up @@ -797,7 +796,7 @@ async def wait_send_all_might_not_block(self):
# semantics that wait_send_all_might_not_block and send_all
# conflict. This also takes care of providing correct checkpoint
# semantics before we potentially error out from _check_status().
async with self._outer_send_lock:
async with self._outer_send_conflict_detector:
self._check_status()
# Then we take the inner send lock. We know that no other tasks
# are calling self.send_all or self.wait_send_all_might_not_block,
Expand Down
29 changes: 16 additions & 13 deletions trio/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

# There's a dependency loop here... _core is allowed to use this file (in fact
# it's the *only* file in the main trio/ package it's allowed to use), but
# UnLock needs yield_briefly so it also has to import _core. Possibly we
# should split this file into two: one for true generic low-level utility
# code, and one for higher level helpers?
# ConflictDetector needs yield_briefly so it also has to import
# _core. Possibly we should split this file into two: one for true generic
# low-level utility code, and one for higher level helpers?
from . import _core

__all__ = [
"signal_raise",
"aiter_compat",
"acontextmanager",
"UnLock",
"ConflictDetector",
"fixup_module_metadata",
]

Expand Down Expand Up @@ -176,24 +176,24 @@ def helper(*args, **kwds):
return helper


class _UnLockSync:
def __init__(self, exc, *args):
self._exc = exc
self._args = args
class _ConflictDetectorSync:
def __init__(self, msg):
self._msg = msg
self._held = False

def __enter__(self):
if self._held:
raise self._exc(*self._args)
raise _core.ResourceBusyError(self._msg)
else:
self._held = True

def __exit__(self, *args):
self._held = False


class UnLock:
"""An unnecessary lock.
class ConflictDetector:
"""Detect when two tasks are about to perform operations that would
conflict.

Use as an async context manager; if two tasks enter it at the same
time then the second one raises an error. You can use it when there are
Expand All @@ -205,10 +205,13 @@ class UnLock:

This executes a checkpoint on entry. That's the only reason it's async.

To use from sync code, do ``with cd.sync``; this is just like ``async with
cd`` except that it doesn't execute a checkpoint.

"""

def __init__(self, exc, *args):
self.sync = _UnLockSync(exc, *args)
def __init__(self, msg):
self.sync = _ConflictDetectorSync(msg)

async def __aenter__(self):
await _core.yield_briefly()
Expand Down
32 changes: 16 additions & 16 deletions trio/testing/_memory_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self):
self._data = bytearray()
self._closed = False
self._lot = _core.ParkingLot()
self._fetch_lock = _util.UnLock(
_core.ResourceBusyError, "another task is already fetching data"
self._fetch_lock = _util.ConflictDetector(
"another task is already fetching data"
)

def close(self):
Expand Down Expand Up @@ -102,8 +102,8 @@ def __init__(
wait_send_all_might_not_block_hook=None,
close_hook=None
):
self._lock = _util.UnLock(
_core.ResourceBusyError, "another task is using this stream"
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream"
)
self._outgoing = _UnboundedByteQueue()
self.send_all_hook = send_all_hook
Expand All @@ -118,7 +118,7 @@ async def send_all(self, data):
# The lock itself is a checkpoint, but then we also yield inside the
# lock to give ourselves a chance to detect buggy user code that calls
# this twice at the same time.
async with self._lock:
async with self._conflict_detector:
await _core.yield_briefly()
self._outgoing.put(data)
if self.send_all_hook is not None:
Expand All @@ -132,7 +132,7 @@ async def wait_send_all_might_not_block(self):
# The lock itself is a checkpoint, but then we also yield inside the
# lock to give ourselves a chance to detect buggy user code that calls
# this twice at the same time.
async with self._lock:
async with self._conflict_detector:
await _core.yield_briefly()
# check for being closed:
self._outgoing.put(b"")
Expand Down Expand Up @@ -201,8 +201,8 @@ class MemoryReceiveStream(ReceiveStream):
"""

def __init__(self, receive_some_hook=None, close_hook=None):
self._lock = _util.UnLock(
_core.ResourceBusyError, "another task is using this stream"
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream"
)
self._incoming = _UnboundedByteQueue()
self._closed = False
Expand All @@ -217,7 +217,7 @@ async def receive_some(self, max_bytes):
# The lock itself is a checkpoint, but then we also yield inside the
# lock to give ourselves a chance to detect buggy user code that calls
# this twice at the same time.
async with self._lock:
async with self._conflict_detector:
await _core.yield_briefly()
if max_bytes is None:
raise TypeError("max_bytes must not be None")
Expand Down Expand Up @@ -435,11 +435,11 @@ def __init__(self):
self._receiver_closed = False
self._receiver_waiting = False
self._waiters = _core.ParkingLot()
self._send_lock = _util.UnLock(
_core.ResourceBusyError, "another task is already sending"
self._send_conflict_detector = _util.ConflictDetector(
"another task is already sending"
)
self._receive_lock = _util.UnLock(
_core.ResourceBusyError, "another task is already receiving"
self._receive_conflict_detector = _util.ConflictDetector(
"another task is already receiving"
)

def _something_happened(self):
Expand All @@ -459,7 +459,7 @@ def close_receiver(self):
self._something_happened()

async def send_all(self, data):
async with self._send_lock:
async with self._send_conflict_detector:
if self._sender_closed:
raise ClosedStreamError
if self._receiver_closed:
Expand All @@ -476,7 +476,7 @@ async def send_all(self, data):
return

async def wait_send_all_might_not_block(self):
async with self._send_lock:
async with self._send_conflict_detector:
if self._sender_closed:
raise ClosedStreamError
if self._receiver_closed:
Expand All @@ -486,7 +486,7 @@ async def wait_send_all_might_not_block(self):
)

async def receive_some(self, max_bytes):
async with self._receive_lock:
async with self._receive_conflict_detector:
# Argument validation
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
Expand Down
14 changes: 6 additions & 8 deletions trio/tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .._highlevel_open_tcp_stream import open_tcp_stream
from .. import ssl as tssl
from .. import socket as tsocket
from .._util import UnLock, acontextmanager
from .._util import ConflictDetector, acontextmanager

from .._core.tests.tutil import slow

Expand Down Expand Up @@ -175,12 +175,10 @@ def __init__(self, sleeper=None):
self._lot = _core.ParkingLot()
self._pending_cleartext = bytearray()

self._send_all_mutex = UnLock(
_core.ResourceBusyError,
self._send_all_conflict_detector = ConflictDetector(
"simultaneous calls to PyOpenSSLEchoStream.send_all"
)
self._receive_some_mutex = UnLock(
_core.ResourceBusyError,
self._receive_some_conflict_detector = ConflictDetector(
"simultaneous calls to PyOpenSSLEchoStream.receive_some"
)

Expand All @@ -205,13 +203,13 @@ def renegotiate(self):
assert self._conn.renegotiate()

async def wait_send_all_might_not_block(self):
async with self._send_all_mutex:
async with self._send_all_conflict_detector:
await _core.yield_briefly()
await self.sleeper("wait_send_all_might_not_block")

async def send_all(self, data):
print(" --> transport_stream.send_all")
async with self._send_all_mutex:
async with self._send_all_conflict_detector:
await _core.yield_briefly()
await self.sleeper("send_all")
self._conn.bio_write(data)
Expand All @@ -233,7 +231,7 @@ async def send_all(self, data):

async def receive_some(self, nbytes):
print(" --> transport_stream.receive_some")
async with self._receive_some_mutex:
async with self._receive_some_conflict_detector:
try:
await _core.yield_briefly()
while True:
Expand Down
18 changes: 6 additions & 12 deletions trio/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,34 @@ def handler(signum, _):
assert record == [signal.SIGFPE]


async def test_UnLock():
ul1 = UnLock(RuntimeError, "ul1")
ul2 = UnLock(ValueError)
async def test_ConflictDetector():
ul1 = ConflictDetector("ul1")
ul2 = ConflictDetector("ul2")

async with ul1:
with assert_yields():
async with ul2:
print("ok")

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(_core.ResourceBusyError) as excinfo:
async with ul1:
with assert_yields():
async with ul1:
pass # pragma: no cover
assert "ul1" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
async with ul2:
with assert_yields():
async with ul2:
pass # pragma: no cover

async def wait_with_ul1():
async with ul1:
await wait_all_tasks_blocked()

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(_core.ResourceBusyError) as excinfo:
async with _core.open_nursery() as nursery:
nursery.spawn(wait_with_ul1)
nursery.spawn(wait_with_ul1)
assert "ul1" in str(excinfo.value)

# mixing sync and async entry
with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(_core.ResourceBusyError) as excinfo:
with ul1.sync:
with assert_yields():
async with ul1:
Expand Down