From f8b2cb9b2af86352a7b4d43fa07d777d62e35639 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Tue, 16 Jul 2024 14:00:33 +0200
Subject: [PATCH 1/2] Use Semaphore backend for lock

---
 distributed/lock.py                 | 161 +++++++---------------------
 distributed/scheduler.py            |   2 -
 distributed/semaphore.py            | 102 +++++++++---------
 distributed/tests/test_locks.py     |  43 ++++----
 distributed/tests/test_semaphore.py |  24 ++---
 5 files changed, 121 insertions(+), 211 deletions(-)

diff --git a/distributed/lock.py b/distributed/lock.py
index 99ec34cd6f7..85d3d8ee8a8 100644
--- a/distributed/lock.py
+++ b/distributed/lock.py
@@ -1,79 +1,31 @@
 from __future__ import annotations
 
-import asyncio
 import logging
-import uuid
-from collections import defaultdict, deque
-
-from dask.utils import parse_timedelta
-
-from distributed.utils import TimeoutError, log_errors, wait_for
-from distributed.worker import get_client
+from distributed.semaphore import Semaphore
 
 logger = logging.getLogger(__name__)
 
 
-class LockExtension:
-    """An extension for the scheduler to manage Locks
+class Lock(Semaphore):
+    """Distributed Centralized Lock
 
-    This adds the following routes to the scheduler
+    .. warning::
 
-    * lock_acquire
-    * lock_release
-    """
+        This uses ``distributed.Semaphore`` as a backend, which is
+        susceptible to lease overbooking. For the Lock this means that if a
+        lease times out, two or more instances could acquire the lock at the
+        same time. To disable lease timeouts, set
+        ``distributed.scheduler.locks.lease-timeout`` to ``"inf"``, e.g.
 
-    def __init__(self, scheduler):
-        self.scheduler = scheduler
-        self.events = defaultdict(deque)
-        self.ids = dict()
+    .. code-block:: python
 
-        self.scheduler.handlers.update(
-            {"lock_acquire": self.acquire, "lock_release": self.release}
-        )
+        with dask.config.set({"distributed.scheduler.locks.lease-timeout": "inf"}):
+            lock = Lock("x")
+            ...
 
-    @log_errors
-    async def acquire(self, name=None, id=None, timeout=None):
-        if isinstance(name, list):
-            name = tuple(name)
-        if name not in self.ids:
-            result = True
-        else:
-            while name in self.ids:
-                event = asyncio.Event()
-                self.events[name].append(event)
-                future = event.wait()
-                if timeout is not None:
-                    future = wait_for(future, timeout)
-                try:
-                    await future
-                except TimeoutError:
-                    result = False
-                    break
-                else:
-                    result = True
-                finally:
-                    event2 = self.events[name].popleft()
-                    assert event is event2
-            if result:
-                assert name not in self.ids
-                self.ids[name] = id
-        return result
-
-    @log_errors
-    def release(self, name=None, id=None):
-        if isinstance(name, list):
-            name = tuple(name)
-        if self.ids.get(name) != id:
-            raise ValueError("This lock has not yet been acquired")
-        del self.ids[name]
-        if self.events[name]:
-            self.scheduler.loop.add_callback(self.events[name][0].set)
-        else:
-            del self.events[name]
-
-
-class Lock:
-    """Distributed Centralized Lock
+    Note that without lease timeouts, the Lock may deadlock in case of
+    cluster downscaling or worker failures.
 
     Parameters
     ----------
@@ -93,28 +45,20 @@ class Lock:
     >>> lock.release()  # doctest: +SKIP
     """
 
-    def __init__(self, name=None, client=None):
-        self._client = client
-        self.name = name or "lock-" + uuid.uuid4().hex
-        self.id = uuid.uuid4().hex
-        self._locked = False
-
-    @property
-    def client(self):
-        if not self._client:
-            try:
-                self._client = get_client()
-            except ValueError:
-                pass
-        return self._client
-
-    def _verify_running(self):
-        if not self.client:
-            raise RuntimeError(
-                f"{type(self)} object not properly initialized. 
This can happen" - " if the object is being deserialized outside of the context of" - " a Client or Worker." - ) + def __init__( + self, + name=None, + register=True, + scheduler_rpc=None, + loop=None, + ): + super().__init__( + max_leases=1, + name=name, + register=register, + scheduler_rpc=scheduler_rpc, + loop=loop, + ) def acquire(self, blocking=True, timeout=None): """Acquire the lock @@ -139,50 +83,21 @@ def acquire(self, blocking=True, timeout=None): ------- True or False whether or not it successfully acquired the lock """ - self._verify_running() - timeout = parse_timedelta(timeout) - if not blocking: if timeout is not None: raise ValueError("can't specify a timeout for a non-blocking call") timeout = 0 + return super().acquire(timeout=timeout) - result = self.client.sync( - self.client.scheduler.lock_acquire, - name=self.name, - id=self.id, - timeout=timeout, - ) - self._locked = True - return result - - def release(self): - """Release the lock if already acquired""" - self._verify_running() - if not self.locked(): - raise ValueError("Lock is not yet acquired") - result = self.client.sync( - self.client.scheduler.lock_release, name=self.name, id=self.id - ) - self._locked = False - return result + async def _locked(self): + val = await self.scheduler.semaphore_value(name=self.name) + return val == 1 def locked(self): - return self._locked - - def __enter__(self): - self.acquire() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.release() - - async def __aenter__(self): - await self.acquire() - return self + return self.sync(self._locked) - async def __aexit__(self, exc_type, exc_value, traceback): - await self.release() + def __getstate__(self): + return self.name - def __reduce__(self): - return (Lock, (self.name,)) + def __setstate__(self, state): + self.__init__(name=state, register=False) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0273d333da3..13b7a237348 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -101,7 +101,6 @@ from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension from distributed.http import get_handlers -from distributed.lock import LockExtension from distributed.metrics import time from distributed.multi_lock import MultiLockExtension from distributed.node import ServerNode @@ -179,7 +178,6 @@ STIMULUS_ID_UNSET = "" DEFAULT_EXTENSIONS = { - "locks": LockExtension, "multi_locks": MultiLockExtension, "publish": PublishExtension, "replay-tasks": ReplayTaskScheduler, diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 59d6951d7b0..9091e9a2654 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -40,7 +40,7 @@ def __init__(self, scheduler): self.max_leases = dict() # {semaphore_name: {lease_id: lease_last_seen_timestamp}} self.leases = defaultdict(dict) - + self.lease_timeouts = dict() self.scheduler.handlers.update( { "semaphore_register": self.create, @@ -70,20 +70,18 @@ def __init__(self, scheduler): self._check_lease_timeout, validation_callback_time * 1000 ) pc.start() - self.lease_timeout = parse_timedelta( - dask.config.get("distributed.scheduler.locks.lease-timeout"), default="s" - ) def get_value(self, name=None): return len(self.leases[name]) # `comm` here is required by the handler interface - def create(self, name=None, max_leases=None): + def create(self, name, max_leases, lease_timeout): # We use `self.max_leases` as the point of truth to find out if a semaphore with a 
specific # `name` has been created. if name not in self.max_leases: assert isinstance(max_leases, int), max_leases self.max_leases[name] = max_leases + self.lease_timeouts[name] = lease_timeout else: if max_leases != self.max_leases[name]: raise ValueError( @@ -128,7 +126,7 @@ def _semaphore_exists(self, name): @log_errors async def acquire(self, name=None, timeout=None, lease_id=None): if not self._semaphore_exists(name): - raise RuntimeError(f"Semaphore `{name}` not known or already closed.") + raise RuntimeError(f"Semaphore or Lock `{name}` not known.") if isinstance(name, list): name = tuple(name) @@ -176,7 +174,7 @@ async def acquire(self, name=None, timeout=None, lease_id=None): def release(self, name=None, lease_id=None): if not self._semaphore_exists(name): logger.warning( - f"Tried to release semaphore `{name}` but it is not known or already closed." + f"Tried to release Lock or Semaphore `{name}` but it is not known." ) return if isinstance(name, list): @@ -185,9 +183,9 @@ def release(self, name=None, lease_id=None): self._release_value(name, lease_id) else: logger.warning( - "Tried to release semaphore but it was already released: " + f"Tried to release Lock or Semaphore but it was already released: " f"{name=}, {lease_id=}. " - "This can happen if the semaphore timed out before." + f"This can happen if the Lock or Semaphore timed out before." ) def _release_value(self, name, lease_id): @@ -201,23 +199,24 @@ def _check_lease_timeout(self): now = time() semaphore_names = list(self.leases.keys()) for name in semaphore_names: - ids = list(self.leases[name]) - logger.debug( - "Validating leases for %s at time %s. Currently known %s", - name, - now, - self.leases[name], - ) - for _id in ids: - time_since_refresh = now - self.leases[name][_id] - if time_since_refresh > self.lease_timeout: - logger.debug( - "Lease %s for %s timed out after %ss.", - _id, - name, - time_since_refresh, - ) - self._release_value(name=name, lease_id=_id) + if lease_timeout := self.lease_timeouts.get(name): + ids = list(self.leases[name]) + logger.debug( + "Validating leases for %s at time %s. Currently known %s", + name, + now, + self.leases[name], + ) + for _id in ids: + time_since_refresh = now - self.leases[name][_id] + if time_since_refresh > lease_timeout: + logger.debug( + "Lease %s for %s timed out after %ss.", + _id, + name, + time_since_refresh, + ) + self._release_value(name=name, lease_id=_id) @log_errors def close(self, name=None): @@ -226,6 +225,7 @@ def close(self, name=None): return del self.max_leases[name] + del self.lease_timeouts[name] if name in self.events: del self.events[name] if name in self.leases: @@ -320,14 +320,6 @@ class Semaphore(SyncMethodMixin): ----- If a client attempts to release the semaphore but doesn't have a lease acquired, this will raise an exception. - - When a semaphore is closed, if, for that closed semaphore, a client attempts to: - - - Acquire a lease: an exception will be raised. - - Release: a warning will be logged. - - Close: nothing will happen. - - dask executes functions by default assuming they are pure, when using semaphore acquire/releases inside such a function, it must be noted that there *are* in fact side-effects, thus, the function can no longer be considered pure. If this is not taken into account, this may lead to unexpected behavior. 
@@ -352,28 +344,28 @@ def __init__( self.refresh_leases = True - self._registered = None + self._do_register = None if register: - self._registered = self.register() + self._do_register = register # this should give ample time to refresh without introducing another # config parameter since this *must* be smaller than the timeout anyhow - refresh_leases_interval = ( - parse_timedelta( + lease_timeout = dask.config.get("distributed.scheduler.locks.lease-timeout") + if lease_timeout != "inf": + lease_timeout = parse_timedelta( dask.config.get("distributed.scheduler.locks.lease-timeout"), default="s", ) - / 5 - ) - pc = PeriodicCallback( - self._refresh_leases, callback_time=refresh_leases_interval * 1000 - ) - self.refresh_callback = pc + refresh_leases_interval = lease_timeout / 5 + pc = PeriodicCallback( + self._refresh_leases, callback_time=refresh_leases_interval * 1000 + ) + self.refresh_callback = pc - # Need to start the callback using IOLoop.add_callback to ensure that the - # PC uses the correct event loop. - if self.loop is not None: - self.loop.add_callback(pc.start) + # Need to start the callback using IOLoop.add_callback to ensure that the + # PC uses the correct event loop. + if self.loop is not None: + self.loop.add_callback(pc.start) @property def scheduler(self): @@ -407,10 +399,17 @@ def _verify_running(self): ) async def _register(self): + lease_timeout = dask.config.get("distributed.scheduler.locks.lease-timeout") + + if lease_timeout == "inf": + lease_timeout = None + else: + lease_timeout = parse_timedelta(lease_timeout, "s") await retry_operation( self.scheduler.semaphore_register, name=self.name, max_leases=self.max_leases, + lease_timeout=lease_timeout, operation=f"semaphore register id={self.id} name={self.name}", ) @@ -419,8 +418,8 @@ def register(self, **kwargs): def __await__(self): async def create_semaphore(): - if self._registered: - await self._registered + if self._do_register: + await self._register() return self return create_semaphore().__await__() @@ -442,6 +441,7 @@ async def _refresh_leases(self): ) async def _acquire(self, timeout=None): + await self lease_id = uuid.uuid4().hex logger.debug( "%s requests lease for %s with ID %s", self.id, self.name, lease_id @@ -527,6 +527,7 @@ def get_value(self): return self.sync(self.scheduler.semaphore_value, name=self.name) def __enter__(self): + self.register() self._verify_running() self.acquire() return self @@ -535,6 +536,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.release() async def __aenter__(self): + await self self._verify_running() await self.acquire() return self diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 7477f354602..d13d14577c3 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -6,37 +6,35 @@ import pytest +import dask + from distributed import Lock, get_client from distributed.metrics import time from distributed.utils_test import gen_cluster -@gen_cluster(client=True, nthreads=[("127.0.0.1", 8)] * 2) +@gen_cluster(client=True, nthreads=[("", 8)] * 2) async def test_lock(c, s, a, b): await c.set_metadata("locked", False) def f(x): client = get_client() - with Lock("x") as lock: + with Lock("x"): assert client.get_metadata("locked") is False client.set_metadata("locked", True) - sleep(0.05) + sleep(0.01) assert client.get_metadata("locked") is True client.set_metadata("locked", False) futures = c.map(f, range(20)) await c.gather(futures) - assert not s.extensions["locks"].events - assert not 
s.extensions["locks"].ids @gen_cluster(client=True) async def test_timeout(c, s, a, b): - locks = s.extensions["locks"] lock = Lock("x") result = await lock.acquire() assert result is True - assert locks.ids["x"] == lock.id lock2 = Lock("x") assert lock.id != lock2.id @@ -46,9 +44,6 @@ async def test_timeout(c, s, a, b): stop = time() assert stop - start < 0.3 assert result is False - assert locks.ids["x"] == lock.id - assert not locks.events["x"] - await lock.release() @@ -56,7 +51,7 @@ async def test_timeout(c, s, a, b): async def test_acquires_with_zero_timeout(c, s, a, b): lock = Lock("x") await lock.acquire(timeout=0) - assert lock.locked() + assert await lock.locked() await lock.release() await lock.acquire(timeout="1s") @@ -69,12 +64,12 @@ async def test_acquires_with_zero_timeout(c, s, a, b): async def test_acquires_blocking(c, s, a, b): lock = Lock("x") await lock.acquire(blocking=False) - assert lock.locked() + assert await lock.locked() await lock.release() - assert not lock.locked() + assert not await lock.locked() with pytest.raises(ValueError): - lock.acquire(blocking=False, timeout=1) + lock.acquire(blocking=False, timeout=0.1) def test_timeout_sync(client): @@ -85,7 +80,7 @@ def test_timeout_sync(client): @gen_cluster(client=True) async def test_errors(c, s, a, b): lock = Lock("x") - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): await lock.release() @@ -95,7 +90,7 @@ def f(x): client = get_client() assert client.get_metadata("locked") is False client.set_metadata("locked", True) - sleep(0.05) + sleep(0.01) assert client.get_metadata("locked") is True client.set_metadata("locked", False) @@ -113,8 +108,6 @@ async def test_lock_types(c, s, a, b): await lock.acquire() await lock.release() - assert not s.extensions["locks"].events - @gen_cluster(client=True) async def test_serializable(c, s, a, b): @@ -129,13 +122,21 @@ def f(x, lock=None): lock2 = pickle.loads(pickle.dumps(lock)) assert lock2.name == lock.name - assert lock2.client is lock.client @gen_cluster(client=True, nthreads=[]) async def test_locks(c, s): async with Lock("x") as l1: l2 = Lock("x") - assert l1.client is c - assert l2.client is c assert await l2.acquire(timeout=0.01) is False + + +@gen_cluster(client=True, nthreads=[]) +async def test_locks_inf_lease_timeout(c, s): + sem_ext = s.extensions["semaphores"] + async with Lock("x"): + assert sem_ext.lease_timeouts["x"] + + with dask.config.set({"distributed.scheduler.locks.lease-timeout": "inf"}): + async with Lock("y"): + assert sem_ext.lease_timeouts.get("y") is None diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index af759832a0b..16a07362b5e 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -100,8 +100,8 @@ def test_timeout_sync(client): @gen_cluster( client=True, config={ - "distributed.scheduler.locks.lease-validation-interval": "200ms", - "distributed.scheduler.locks.lease-timeout": "200ms", + "distributed.scheduler.locks.lease-validation-interval": "100ms", + "distributed.scheduler.locks.lease-timeout": "100ms", }, ) async def test_release_semaphore_after_timeout(c, s, a, b): @@ -199,14 +199,16 @@ async def test_close_async(c, s, a): assert await sem.acquire() with pytest.warns( RuntimeWarning, - match="Closing semaphore .* but there remain unreleased leases .*", + match="Closing semaphore test but there remain unreleased leases .*", ): await sem.close() - - with pytest.raises( - RuntimeError, match="Semaphore `test` not known or already closed." 
+ # After close, the semaphore is reset + await sem.acquire() + with pytest.warns( + RuntimeWarning, + match="Closing semaphore test but there remain unreleased leases .*", ): - await sem.acquire() + await sem.close() sem2 = await Semaphore(name="t2", max_leases=1) assert await sem2.acquire() @@ -231,14 +233,6 @@ def f(sem_): assert not metric_dict -def test_close_sync(client): - sem = Semaphore() - sem.close() - - with pytest.raises(RuntimeError, match="Semaphore .* not known or already closed."): - sem.acquire() - - @gen_cluster(client=True) async def test_release_once_too_many(c, s, a, b): sem = await Semaphore(name="x") From c5707072ada61654850ef04915fed5bf85ac11ff Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 17 Jul 2024 13:47:13 +0200 Subject: [PATCH 2/2] add deprecation warning --- distributed/lock.py | 12 ++++++++++++ distributed/semaphore.py | 32 ++++++++++++++++++-------------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/distributed/lock.py b/distributed/lock.py index 85d3d8ee8a8..4c79303bd2e 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -6,6 +6,8 @@ logger = logging.getLogger(__name__) +_no_value = object() + class Lock(Semaphore): """Distributed Centralized Lock @@ -48,10 +50,20 @@ class Lock(Semaphore): def __init__( self, name=None, + client=_no_value, register=True, scheduler_rpc=None, loop=None, ): + if client is not _no_value: + import warnings + + warnings.warn( + "The `client` parameter is deprecated. It is no longer necessary to pass a client to Lock.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( max_leases=1, name=name, diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 9091e9a2654..fd971d46e8c 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -351,21 +351,25 @@ def __init__( # this should give ample time to refresh without introducing another # config parameter since this *must* be smaller than the timeout anyhow lease_timeout = dask.config.get("distributed.scheduler.locks.lease-timeout") - if lease_timeout != "inf": - lease_timeout = parse_timedelta( - dask.config.get("distributed.scheduler.locks.lease-timeout"), - default="s", - ) - refresh_leases_interval = lease_timeout / 5 - pc = PeriodicCallback( - self._refresh_leases, callback_time=refresh_leases_interval * 1000 - ) - self.refresh_callback = pc + if lease_timeout == "inf": + return + + ## Below is all code for the lease timout validation + + lease_timeout = parse_timedelta( + dask.config.get("distributed.scheduler.locks.lease-timeout"), + default="s", + ) + refresh_leases_interval = lease_timeout / 5 + pc = PeriodicCallback( + self._refresh_leases, callback_time=refresh_leases_interval * 1000 + ) + self.refresh_callback = pc - # Need to start the callback using IOLoop.add_callback to ensure that the - # PC uses the correct event loop. - if self.loop is not None: - self.loop.add_callback(pc.start) + # Need to start the callback using IOLoop.add_callback to ensure that the + # PC uses the correct event loop. + if self.loop is not None: + self.loop.add_callback(pc.start) @property def scheduler(self):
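Usage note (illustrative, not part of the patch): a minimal sketch of how the reworked Lock is meant to be used, assuming a local ``Client``; the config key and the ``"inf"`` sentinel are taken from the docstring added in patch 1.

    import dask
    from distributed import Client, Lock

    client = Client()  # hypothetical local cluster, for illustration only

    # Disabling the lease timeout avoids lease overbooking (two holders at once)
    # at the cost of possible deadlock if a lock holder's worker dies.
    with dask.config.set({"distributed.scheduler.locks.lease-timeout": "inf"}):
        lock = Lock("x")
        with lock:                 # acquired through the Semaphore backend (max_leases=1)
            assert lock.locked()   # reads the semaphore value from the scheduler
            ...                    # critical section
        assert not lock.locked()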
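The second patch keeps ``client`` as an accepted but deprecated keyword argument. A short sketch of the expected behaviour under that change (the cluster setup is illustrative):

    import warnings

    from distributed import Client, Lock

    client = Client()  # hypothetical cluster, for illustration only

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Passing a client still works but triggers a DeprecationWarning;
        # the argument is otherwise ignored by the Semaphore-backed Lock.
        Lock("x", client=client)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)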