Skip to content

Commit

Permalink
Unpickle Events, Variables, Queues and Semaphore safely without Clien…
Browse files Browse the repository at this point in the history
…t context (#7579)
  • Loading branch information
fjetter authored Feb 23, 2023
1 parent 41fdb91 commit 790d852
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 54 deletions.
20 changes: 15 additions & 5 deletions distributed/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

from dask.utils import parse_timedelta

from distributed.client import Client
from distributed.utils import TimeoutError, log_errors
from distributed.worker import get_worker
from distributed.worker import get_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -180,10 +179,9 @@ class Event:

def __init__(self, name=None, client=None):
try:
self.client = client or Client.current()
self.client = client or get_client()
except ValueError:
# Initialise new client
self.client = get_worker().client
self.client = None
self.name = name or "event-" + uuid.uuid4().hex

def __await__(self):
Expand All @@ -201,6 +199,14 @@ async def _():

return _().__await__()

def _verify_running(self):
if not self.client:
raise RuntimeError(
f"{type(self)} object not properly initialized. This can happen"
" if the object is being deserialized outside of the context of"
" a Client or Worker."
)

def wait(self, timeout=None):
"""Wait until the event is set.
Expand All @@ -221,6 +227,7 @@ def wait(self, timeout=None):
-------
True if the event was set, or False if a timeout happened
"""
self._verify_running()
timeout = parse_timedelta(timeout)

result = self.client.sync(
Expand All @@ -233,18 +240,21 @@ def clear(self):
All waiters will now block.
"""
self._verify_running()
return self.client.sync(self.client.scheduler.event_clear, name=self.name)

def set(self):
    """Set the event (set its flag to true).

    All waiters will now be released.

    Returns
    -------
    The result of the scheduler's ``event_set`` call, as returned by
    ``client.sync``.

    Raises
    ------
    RuntimeError
        If no client is attached to this object (see ``_verify_running``).
    """
    self._verify_running()
    return self.client.sync(self.client.scheduler.event_set, name=self.name)

def is_set(self):
    """Ask the scheduler whether the event is currently set."""
    self._verify_running()
    return self.client.sync(
        self.client.scheduler.event_is_set,
        name=self.name,
    )

Expand Down
44 changes: 23 additions & 21 deletions distributed/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from dask.utils import parse_timedelta, stringify

from distributed.client import Client, Future
from distributed.worker import get_client, get_worker
from distributed.client import Future
from distributed.worker import get_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -167,17 +167,24 @@ class Queue:

def __init__(self, name=None, client=None, maxsize=0):
try:
self.client = client or Client.current()
self.client = client or get_client()
except ValueError:
# Initialise new client
self.client = get_worker().client
self.client = None
self.name = name or "queue-" + uuid.uuid4().hex
self.maxsize = maxsize

if self.client.asynchronous:
self._started = asyncio.ensure_future(self._start())
else:
self.client.sync(self._start)
if self.client:
if self.client.asynchronous:
self._started = asyncio.ensure_future(self._start())
else:
self.client.sync(self._start)

def _verify_running(self):
if not self.client:
raise RuntimeError(
f"{type(self)} object not properly initialized. This can happen"
" if the object is being deserialized outside of the context of"
" a Client or Worker."
)

async def _start(self):
await self.client.scheduler.queue_create(name=self.name, maxsize=self.maxsize)
Expand Down Expand Up @@ -213,6 +220,7 @@ def put(self, value, timeout=None, **kwargs):
Instead of number of seconds, it is also possible to specify
a timedelta in string format, e.g. "200ms".
"""
self._verify_running()
timeout = parse_timedelta(timeout)
return self.client.sync(self._put, value, timeout=timeout, **kwargs)

Expand All @@ -230,11 +238,13 @@ def get(self, timeout=None, batch=False, **kwargs):
If an integer then return that many elements from the queue
If False (default) then return one item at a time
"""
self._verify_running()
timeout = parse_timedelta(timeout)
return self.client.sync(self._get, timeout=timeout, batch=batch, **kwargs)

def qsize(self, **kwargs):
    """Current number of elements in the queue

    Any keyword arguments are forwarded to ``client.sync`` together with
    the ``_qsize`` coroutine function.
    """
    self._verify_running()
    size = self.client.sync(self._qsize, **kwargs)
    return size

async def _get(self, timeout=None, batch=False):
Expand Down Expand Up @@ -267,17 +277,9 @@ async def _qsize(self):
return result

def close(self):
    """Tell the scheduler to release this queue.

    The ``queue_release`` message is only sent while the client's status
    is ``"running"``; for any other status nothing is sent.
    """
    self._verify_running()
    if self.client.status != "running":
        return
    # TODO: can leave zombie futures
    release_msg = {"op": "queue_release", "name": self.name}
    self.client._send_to_scheduler(release_msg)

def __getstate__(self):
return (self.name, self.client.scheduler.address)

def __setstate__(self, state):
name, address = state
try:
client = get_client(address)
assert client.scheduler.address == address
except (AttributeError, AssertionError):
client = Client(address, set_as_default=False)
self.__init__(name=name, client=client)
def __reduce__(self):
return type(self), (self.name, None, self.maxsize)
35 changes: 26 additions & 9 deletions distributed/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,15 +345,19 @@ def __init__(
loop=None,
):
try:
worker = get_worker()
self.scheduler = scheduler_rpc or worker.scheduler
self.loop = loop or worker.loop

try:
worker = get_worker()
self.scheduler = scheduler_rpc or worker.scheduler
self.loop = loop or worker.loop

except ValueError:
client = get_client()
self.scheduler = scheduler_rpc or client.scheduler
self.loop = loop or client.loop
except ValueError:
client = get_client()
self.scheduler = scheduler_rpc or client.scheduler
self.loop = loop or client.loop

# This happens if this is deserialized on the scheduler
self.scheduler = None
self.loop = None
self.name = name or "semaphore-" + uuid.uuid4().hex
self.max_leases = max_leases
self.id = uuid.uuid4().hex
Expand Down Expand Up @@ -381,7 +385,14 @@ def __init__(

# Need to start the callback using IOLoop.add_callback to ensure that the
# PC uses the correct event loop.
self.loop.add_callback(pc.start)
if self.loop is not None:
self.loop.add_callback(pc.start)

def _verify_running(self):
if not self.scheduler or not self.loop:
raise RuntimeError(
f"{type(self)} object not properly initialized. This can happen if the object is being deserialized outside of the context of a Client or Worker."
)

async def _register(self):
await retry_operation(
Expand Down Expand Up @@ -453,6 +464,7 @@ def acquire(self, timeout=None):
Instead of number of seconds, it is also possible to specify
a timedelta in string format, e.g. "200ms".
"""
self._verify_running()
timeout = parse_timedelta(timeout)
return self.sync(self._acquire, timeout=timeout)

Expand Down Expand Up @@ -486,6 +498,7 @@ def release(self):
immediately, but it will always be automatically released after a specific interval configured using
"distributed.scheduler.locks.lease-validation-interval" and "distributed.scheduler.locks.lease-timeout".
"""
self._verify_running()
if not self._leases:
raise RuntimeError("Released too often")

Expand All @@ -498,16 +511,19 @@ def get_value(self):
"""
Return the number of currently registered leases.
"""
self._verify_running()
return self.sync(self.scheduler.semaphore_value, name=self.name)

def __enter__(self):
self._verify_running()
self.acquire()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.release()

async def __aenter__(self):
self._verify_running()
await self.acquire()
return self

Expand All @@ -528,6 +544,7 @@ def __setstate__(self, state):
)

def close(self):
    """Stop the refresh callback and close the semaphore on the scheduler.

    Returns whatever ``self.sync`` returns for the scheduler's
    ``semaphore_close`` call.
    """
    self._verify_running()
    self.refresh_callback.stop()
    outcome = self.sync(self.scheduler.semaphore_close, name=self.name)
    return outcome

Expand Down
34 changes: 33 additions & 1 deletion distributed/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import pickle
from datetime import timedelta

from distributed import Event
import pytest

from distributed import Client, Event
from distributed.utils_test import gen_cluster


Expand Down Expand Up @@ -220,3 +222,33 @@ def event_is_set(event_name):

assert not s.extensions["events"]._events
assert not s.extensions["events"]._waiter_count


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
    """Round-trip an Event through pickle while no client or worker is active.

    This typically happens if the object is being deserialized on the
    scheduler.
    """
    original = await Event()
    payload = pickle.dumps(original)
    await c.close()

    # Unpickling must not initialize a client: there is no current client
    # before the load, and there must still be none after it.
    with pytest.raises(ValueError):
        Client.current()
    detached = pickle.loads(payload)
    with pytest.raises(ValueError):
        Client.current()

    assert detached.client is None
    with pytest.raises(RuntimeError, match="not properly initialized"):
        await detached.set()

    # With a client active again, a fresh load is fully usable.
    async with Client(s.address, asynchronous=True):
        attached = pickle.loads(payload)
        await attached.set()
        await attached.wait()
33 changes: 32 additions & 1 deletion distributed/tests/test_queues.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import pickle
from datetime import timedelta
from time import sleep

Expand Down Expand Up @@ -86,7 +87,6 @@ async def test_hold_futures(s, a, b):
assert result == 11


@pytest.mark.skip(reason="getting same client from main thread")
@gen_cluster(client=True)
async def test_picklability(c, s, a, b):
q = Queue()
Expand Down Expand Up @@ -301,3 +301,34 @@ def foo():

result = c.submit(foo).result()
assert result == 123


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
    """Round-trip a Queue through pickle while no client or worker is active.

    This typically happens if the object is being deserialized on the
    scheduler.
    """
    original = await Queue()
    payload = pickle.dumps(original)
    await c.close()

    # Unpickling must not initialize a client: there is no current client
    # before the load, and there must still be none after it.
    with pytest.raises(ValueError):
        Client.current()
    detached = pickle.loads(payload)
    with pytest.raises(ValueError):
        Client.current()

    assert detached.client is None
    await asyncio.sleep(0)
    with pytest.raises(RuntimeError, match="not properly initialized"):
        await detached.put(1)

    # With a client active again, a fresh load is fully usable.
    async with Client(s.address, asynchronous=True):
        attached = pickle.loads(payload)
        await attached.put(1)
        assert await attached.get() == 1
31 changes: 31 additions & 0 deletions distributed/tests/test_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,34 @@ async def test_release_failure(c, s, a, b, caplog):
assert ext.get_value(name) == 1 # lease is still registered
while not (await semaphore.get_value() == 0):
await asyncio.sleep(0.01)


@gen_cluster(client=True, nthreads=[])
async def test_unpickle_without_client(c, s):
    """Round-trip a Semaphore through pickle while no client or worker is active.

    This typically happens if the object is being deserialized on the
    scheduler.
    """
    original = await Semaphore()
    payload = pickle.dumps(original)
    await c.close()

    # Unpickling must not initialize a client: there is no current client
    # before the load, and there must still be none after it.
    with pytest.raises(ValueError):
        Client.current()
    detached = pickle.loads(payload)
    with pytest.raises(ValueError):
        Client.current()

    assert detached.scheduler is None
    await asyncio.sleep(0)
    # The refresh callback must not be running on the detached object.
    assert not detached.refresh_callback.is_running()
    with pytest.raises(RuntimeError, match="not properly initialized"):
        await detached.acquire()

    # With a client active again, a fresh load is fully usable.
    async with Client(s.address, asynchronous=True):
        attached = pickle.loads(payload)
        assert await attached.acquire()
Loading

0 comments on commit 790d852

Please sign in to comment.