From 790d8522b97635de1d85a4414031ab31ce12a089 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 23 Feb 2023 15:55:11 +0100 Subject: [PATCH] Unpickle Events, Variables, Queues and Semaphore safely without Client context (#7579) --- distributed/event.py | 20 +++++++++---- distributed/queues.py | 44 +++++++++++++++-------------- distributed/semaphore.py | 35 +++++++++++++++++------ distributed/tests/test_events.py | 34 +++++++++++++++++++++- distributed/tests/test_queues.py | 33 +++++++++++++++++++++- distributed/tests/test_semaphore.py | 31 ++++++++++++++++++++ distributed/tests/test_variable.py | 31 ++++++++++++++++++++ distributed/variable.py | 35 ++++++++++++----------- 8 files changed, 209 insertions(+), 54 deletions(-) diff --git a/distributed/event.py b/distributed/event.py index 02b2a44179..f09b9c8980 100644 --- a/distributed/event.py +++ b/distributed/event.py @@ -8,9 +8,8 @@ from dask.utils import parse_timedelta -from distributed.client import Client from distributed.utils import TimeoutError, log_errors -from distributed.worker import get_worker +from distributed.worker import get_client logger = logging.getLogger(__name__) @@ -180,10 +179,9 @@ class Event: def __init__(self, name=None, client=None): try: - self.client = client or Client.current() + self.client = client or get_client() except ValueError: - # Initialise new client - self.client = get_worker().client + self.client = None self.name = name or "event-" + uuid.uuid4().hex def __await__(self): @@ -201,6 +199,14 @@ async def _(): return _().__await__() + def _verify_running(self): + if not self.client: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) + def wait(self, timeout=None): """Wait until the event is set. @@ -221,6 +227,7 @@ def wait(self, timeout=None): ------- True if the event was set of false, if a timeout happened """ + self._verify_running() timeout = parse_timedelta(timeout) result = self.client.sync( @@ -233,6 +240,7 @@ def clear(self): All waiters will now block. """ + self._verify_running() return self.client.sync(self.client.scheduler.event_clear, name=self.name) def set(self): @@ -240,11 +248,13 @@ def set(self): All waiters will now be released. """ + self._verify_running() result = self.client.sync(self.client.scheduler.event_set, name=self.name) return result def is_set(self): """Check if the event is set""" + self._verify_running() result = self.client.sync(self.client.scheduler.event_is_set, name=self.name) return result diff --git a/distributed/queues.py b/distributed/queues.py index 668b04e89d..aca3e026d9 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -7,8 +7,8 @@ from dask.utils import parse_timedelta, stringify -from distributed.client import Client, Future -from distributed.worker import get_client, get_worker +from distributed.client import Future +from distributed.worker import get_client logger = logging.getLogger(__name__) @@ -167,17 +167,24 @@ class Queue: def __init__(self, name=None, client=None, maxsize=0): try: - self.client = client or Client.current() + self.client = client or get_client() except ValueError: - # Initialise new client - self.client = get_worker().client + self.client = None self.name = name or "queue-" + uuid.uuid4().hex self.maxsize = maxsize - - if self.client.asynchronous: - self._started = asyncio.ensure_future(self._start()) - else: - self.client.sync(self._start) + if self.client: + if self.client.asynchronous: + self._started = asyncio.ensure_future(self._start()) + else: + self.client.sync(self._start) + + def _verify_running(self): + if not self.client: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) async def _start(self): await self.client.scheduler.queue_create(name=self.name, maxsize=self.maxsize) @@ -213,6 +220,7 @@ def put(self, value, timeout=None, **kwargs): Instead of number of seconds, it is also possible to specify a timedelta in string format, e.g. "200ms". """ + self._verify_running() timeout = parse_timedelta(timeout) return self.client.sync(self._put, value, timeout=timeout, **kwargs) @@ -230,11 +238,13 @@ def get(self, timeout=None, batch=False, **kwargs): If an integer than return that many elements from the queue If False (default) then return one item at a time """ + self._verify_running() timeout = parse_timedelta(timeout) return self.client.sync(self._get, timeout=timeout, batch=batch, **kwargs) def qsize(self, **kwargs): """Current number of elements in the queue""" + self._verify_running() return self.client.sync(self._qsize, **kwargs) async def _get(self, timeout=None, batch=False): @@ -267,17 +277,9 @@ async def _qsize(self): return result def close(self): + self._verify_running() if self.client.status == "running": # TODO: can leave zombie futures self.client._send_to_scheduler({"op": "queue_release", "name": self.name}) - def __getstate__(self): - return (self.name, self.client.scheduler.address) - - def __setstate__(self, state): - name, address = state - try: - client = get_client(address) - assert client.scheduler.address == address - except (AttributeError, AssertionError): - client = Client(address, set_as_default=False) - self.__init__(name=name, client=client) + def __reduce__(self): + return type(self), (self.name, None, self.maxsize) diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 8e1d782758..7d22c596c6 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -345,15 +345,19 @@ def __init__( loop=None, ): try: - worker = get_worker() - self.scheduler = scheduler_rpc or worker.scheduler - self.loop = loop or worker.loop - + try: + worker = get_worker() + self.scheduler = scheduler_rpc or worker.scheduler + self.loop = loop or worker.loop + + except ValueError: + client = get_client() + self.scheduler = scheduler_rpc or client.scheduler + self.loop = loop or client.loop except ValueError: - client = get_client() - self.scheduler = scheduler_rpc or client.scheduler - self.loop = loop or client.loop - + # This happens if this is deserialized on the scheduler + self.scheduler = None + self.loop = None self.name = name or "semaphore-" + uuid.uuid4().hex self.max_leases = max_leases self.id = uuid.uuid4().hex @@ -381,7 +385,14 @@ def __init__( # Need to start the callback using IOLoop.add_callback to ensure that the # PC uses the correct event loop. - self.loop.add_callback(pc.start) + if self.loop is not None: + self.loop.add_callback(pc.start) + + def _verify_running(self): + if not self.scheduler or not self.loop: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen if the object is being deserialized outside of the context of a Client or Worker." + ) async def _register(self): await retry_operation( @@ -453,6 +464,7 @@ def acquire(self, timeout=None): Instead of number of seconds, it is also possible to specify a timedelta in string format, e.g. "200ms". """ + self._verify_running() timeout = parse_timedelta(timeout) return self.sync(self._acquire, timeout=timeout) @@ -486,6 +498,7 @@ def release(self): immediately, but it will always be automatically released after a specific interval configured using "distributed.scheduler.locks.lease-validation-interval" and "distributed.scheduler.locks.lease-timeout". """ + self._verify_running() if not self._leases: raise RuntimeError("Released too often") @@ -498,9 +511,11 @@ def get_value(self): """ Return the number of currently registered leases. """ + self._verify_running() return self.sync(self.scheduler.semaphore_value, name=self.name) def __enter__(self): + self._verify_running() self.acquire() return self @@ -508,6 +523,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.release() async def __aenter__(self): + self._verify_running() await self.acquire() return self @@ -528,6 +544,7 @@ def __setstate__(self, state): ) def close(self): + self._verify_running() self.refresh_callback.stop() return self.sync(self.scheduler.semaphore_close, name=self.name) diff --git a/distributed/tests/test_events.py b/distributed/tests/test_events.py index 1389a38458..5d6f4bbe84 100644 --- a/distributed/tests/test_events.py +++ b/distributed/tests/test_events.py @@ -3,7 +3,9 @@ import pickle from datetime import timedelta -from distributed import Event +import pytest + +from distributed import Client, Event from distributed.utils_test import gen_cluster @@ -220,3 +222,33 @@ def event_is_set(event_name): assert not s.extensions["events"]._events assert not s.extensions["events"]._waiter_count + + +@gen_cluster(client=True, nthreads=[]) +async def test_unpickle_without_client(c, s): + """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. + + This typically happens if the object is being deserialized on the scheduler. + """ + obj = await Event() + pickled = pickle.dumps(obj) + await c.close() + + # We do not want to initialize a client during unpickling + with pytest.raises(ValueError): + Client.current() + + obj2 = pickle.loads(pickled) + + with pytest.raises(ValueError): + Client.current() + + assert obj2.client is None + + with pytest.raises(RuntimeError, match="not properly initialized"): + await obj2.set() + + async with Client(s.address, asynchronous=True): + obj3 = pickle.loads(pickled) + await obj3.set() + await obj3.wait() diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index bd33a5f410..abbc1cf3ea 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import pickle from datetime import timedelta from time import sleep @@ -86,7 +87,6 @@ async def test_hold_futures(s, a, b): assert result == 11 -@pytest.mark.skip(reason="getting same client from main thread") @gen_cluster(client=True) async def test_picklability(c, s, a, b): q = Queue() @@ -301,3 +301,34 @@ def foo(): result = c.submit(foo).result() assert result == 123 + + +@gen_cluster(client=True, nthreads=[]) +async def test_unpickle_without_client(c, s): + """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. + + This typically happens if the object is being deserialized on the scheduler. + """ + q = await Queue() + pickled = pickle.dumps(q) + await c.close() + + # We do not want to initialize a client during unpickling + with pytest.raises(ValueError): + Client.current() + + q2 = pickle.loads(pickled) + + with pytest.raises(ValueError): + Client.current() + + assert q2.client is None + await asyncio.sleep(0) + + with pytest.raises(RuntimeError, match="not properly initialized"): + await q2.put(1) + + async with Client(s.address, asynchronous=True): + q3 = pickle.loads(pickled) + await q3.put(1) + assert await q3.get() == 1 diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 16adaae2fd..a855a63f4f 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -575,3 +575,34 @@ async def test_release_failure(c, s, a, b, caplog): assert ext.get_value(name) == 1 # lease is still registered while not (await semaphore.get_value() == 0): await asyncio.sleep(0.01) + + +@gen_cluster(client=True, nthreads=[]) +async def test_unpickle_without_client(c, s): + """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. + + This typically happens if the object is being deserialized on the scheduler. + """ + sem = await Semaphore() + pickled = pickle.dumps(sem) + await c.close() + + # We do not want to initialize a client during unpickling + with pytest.raises(ValueError): + Client.current() + + s2 = pickle.loads(pickled) + + with pytest.raises(ValueError): + Client.current() + + assert s2.scheduler is None + await asyncio.sleep(0) + assert not s2.refresh_callback.is_running() + + with pytest.raises(RuntimeError, match="not properly initialized"): + await s2.acquire() + + async with Client(s.address, asynchronous=True): + s3 = pickle.loads(pickled) + assert await s3.acquire() diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 9287e6a424..2e711e9940 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import pickle import random from datetime import timedelta from time import sleep @@ -294,3 +295,33 @@ async def test_variables_do_not_leak_client(c, s, a, b): while set(s.clients) != clients_pre: await asyncio.sleep(0.01) assert time() < start + 5 + + +@gen_cluster(client=True, nthreads=[]) +async def test_unpickle_without_client(c, s): + """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context. + + This typically happens if the object is being deserialized on the scheduler. + """ + obj = Variable("foo") + pickled = pickle.dumps(obj) + await c.close() + + # We do not want to initialize a client during unpickling + with pytest.raises(ValueError): + Client.current() + + obj2 = pickle.loads(pickled) + + with pytest.raises(ValueError): + Client.current() + + assert obj2.client is None + + with pytest.raises(RuntimeError, match="not properly initialized"): + await obj2.set(42) + + async with Client(s.address, asynchronous=True): + obj3 = pickle.loads(pickled) + await obj3.set(42) + assert await obj3.get() == 42 diff --git a/distributed/variable.py b/distributed/variable.py index 801916fb3b..2eb1fe6706 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -10,10 +10,10 @@ from dask.utils import parse_timedelta, stringify -from distributed.client import Client, Future +from distributed.client import Future from distributed.metrics import time from distributed.utils import TimeoutError, log_errors -from distributed.worker import get_client, get_worker +from distributed.worker import get_client logger = logging.getLogger(__name__) @@ -165,14 +165,21 @@ class Variable: Queue: shared multi-producer/multi-consumer queue between clients """ - def __init__(self, name=None, client=None, maxsize=0): + def __init__(self, name=None, client=None): try: - self.client = client or Client.current() + self.client = client or get_client() except ValueError: - # Initialise new client - self.client = get_worker().client + self.client = None self.name = name or "variable-" + uuid.uuid4().hex + def _verify_running(self): + if not self.client: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) + async def _set(self, value): if isinstance(value, Future): await self.client.scheduler.variable_set( @@ -189,6 +196,7 @@ def set(self, value, **kwargs): value : Future or object Must be either a Future or a msgpack-encodable value """ + self._verify_running() return self.client.sync(self._set, value, **kwargs) async def _get(self, timeout=None): @@ -221,6 +229,7 @@ def get(self, timeout=None, **kwargs): Instead of number of seconds, it is also possible to specify a timedelta in string format, e.g. "200ms". """ + self._verify_running() timeout = parse_timedelta(timeout) return self.client.sync(self._get, timeout=timeout, **kwargs) @@ -229,17 +238,9 @@ def delete(self): Caution, this affects all clients currently pointing to this variable. """ + self._verify_running() if self.client.status == "running": # TODO: can leave zombie futures self.client._send_to_scheduler({"op": "variable_delete", "name": self.name}) - def __getstate__(self): - return (self.name, self.client.scheduler.address) - - def __setstate__(self, state): - name, address = state - try: - client = get_client(address) - assert client.scheduler.address == address - except (AttributeError, AssertionError): - client = Client(address, set_as_default=False) - self.__init__(name=name, client=client) + def __reduce__(self): + return Variable, (self.name,)