Skip to content

Commit

Permalink
Use worker comm pool in Semaphore (#4195)
Browse files Browse the repository at this point in the history
* Semaphore uses worker comm pool
* Switch semaphore logging to debug level
* Align usage of loop and scheduler attribute names in Semaphore
  • Loading branch information
fjetter authored Jan 20, 2021
1 parent 9442d9b commit e736c0b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 88 deletions.
155 changes: 97 additions & 58 deletions distributed/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from collections import defaultdict, deque

import dask
from tornado.ioloop import PeriodicCallback
from tornado.ioloop import IOLoop, PeriodicCallback

from distributed.utils_comm import retry_operation

from .metrics import time
from .utils import log_errors, parse_timedelta
from .worker import get_client
from .utils import log_errors, parse_timedelta, sync, thread_state
from .worker import get_client, get_worker

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -130,7 +131,7 @@ def _get_lease(self, name, lease_id):
or len(self.leases[name]) < self.max_leases[name]
):
now = time()
logger.info("Acquire lease %s for %s at %s", lease_id, name, now)
logger.debug("Acquire lease %s for %s at %s", lease_id, name, now)
self.leases[name][lease_id] = now
self.metrics["acquire_total"][name] += 1
else:
Expand All @@ -154,8 +155,8 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):

self.metrics["pending"][name] += 1
while True:
logger.info(
"Trying to acquire %s for %s with %ss left.",
logger.debug(
"Trying to acquire %s for %s with %s seconds left.",
lease_id,
name,
w.leftover(),
Expand All @@ -177,7 +178,7 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
continue
except TimeoutError:
result = False
logger.info(
logger.debug(
"Acquisition of lease %s for %s is %s after waiting for %ss.",
lease_id,
name,
Expand Down Expand Up @@ -210,7 +211,7 @@ def release(self, comm=None, name=None, lease_id=None):
)

def _release_value(self, name, lease_id):
logger.info("Releasing %s for %s", lease_id, name)
logger.debug("Releasing %s for %s", lease_id, name)
# Everything needs to be atomic here.
del self.leases[name][lease_id]
self.events[name].set()
Expand All @@ -230,7 +231,7 @@ def _check_lease_timeout(self):
for _id in ids:
time_since_refresh = now - self.leases[name][_id]
if time_since_refresh > self.lease_timeout:
logger.info(
logger.debug(
"Lease %s for %s timed out after %ss.",
_id,
name,
Expand Down Expand Up @@ -311,15 +312,19 @@ class Semaphore:
Name of the semaphore to acquire. Choosing the same name allows two
disconnected processes to coordinate. If not given, a random
name will be generated.
client: Client (optional)
Client to use for communication with the scheduler. If not given, the
default global client will be used.
register: bool
If True, register the semaphore with the scheduler. This needs to be
done before any leases can be acquired. If not done during
initialization, this can also be done by calling the register method of
this class.
When registering, this needs to be awaited.
scheduler_rpc: ConnectionPool
The ConnectionPool to connect to the scheduler. If None is provided, it
uses the worker or client pool. This parameter is mostly used for
testing.
loop: IOLoop
The event loop this instance is using. If None is provided, reuse the
loop of the active worker or client.
Examples
--------
Expand Down Expand Up @@ -355,8 +360,25 @@ class Semaphore:
"""

def __init__(self, max_leases=1, name=None, client=None, register=True):
self.client = client or get_client()
def __init__(
self,
max_leases=1,
name=None,
register=True,
scheduler_rpc=None,
loop=None,
):

try:
worker = get_worker()
self.scheduler = scheduler_rpc or worker.scheduler
self.loop = loop or worker.loop

except ValueError:
client = get_client()
self.scheduler = scheduler_rpc or client.scheduler
self.loop = loop or client.io_loop

self.name = name or "semaphore-" + uuid.uuid4().hex
self.max_leases = max_leases
self.id = uuid.uuid4().hex
Expand All @@ -381,27 +403,25 @@ def __init__(self, max_leases=1, name=None, client=None, register=True):
self._refresh_leases, callback_time=refresh_leases_interval * 1000
)
self.refresh_callback = pc
# Registering the pc to the client here is important for proper cleanup
self._periodic_callback_name = f"refresh_semaphores_{self.id}"
self.client._periodic_callbacks[self._periodic_callback_name] = pc

# Need to start the callback using IOLoop.add_callback to ensure that the
# PC uses the correct event loop.
self.client.io_loop.add_callback(pc.start)
self.loop.add_callback(pc.start)

def register(self):
"""
Register the semaphore on scheduler side
@property
def asynchronous(self):
    """Return True when the current IOLoop is this instance's loop."""
    current_loop = IOLoop.current()
    return current_loop is self.loop

This will register the semaphore on scheduler side and ensure that all necessary data structures exist.
"""
if self._registered is None:
self._registered = self.client.sync(
self.client.scheduler.semaphore_register,
name=self.name,
max_leases=self.max_leases,
)
return self._registered
async def _register(self):
    """Call the scheduler's ``semaphore_register`` handler through ``retry_operation``.

    Sends this semaphore's ``name`` and ``max_leases``; the ``operation``
    string is a human-readable label identifying this call.
    """
    register_rpc = self.scheduler.semaphore_register
    # Built in the same order as the keyword arguments of the direct call.
    call_kwargs = {
        "name": self.name,
        "max_leases": self.max_leases,
        "operation": f"semaphore register id={self.id} name={self.name}",
    }
    await retry_operation(register_rpc, **call_kwargs)

def register(self, **kwargs):
    """Run ``_register`` through ``self.sync`` and return whatever ``sync`` returns.

    NOTE(review): ``**kwargs`` is accepted but not forwarded to ``sync`` or
    ``_register``; any keyword passed here is silently ignored.
    """
    registration = self._register
    return self.sync(registration)

def __await__(self):
async def create_semaphore():
Expand All @@ -411,34 +431,53 @@ async def create_semaphore():

return create_semaphore().__await__()

def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs):
    """Call ``func`` either directly or through the module-level ``sync`` helper.

    Parameters
    ----------
    func: callable
        Called with ``*args`` and ``**kwargs``.
    asynchronous: bool (optional)
        If truthy, force the direct-call path regardless of the loop state.
    callback_timeout: number or str (optional)
        Passed through ``parse_timedelta``. When not None, it bounds the
        call via ``asyncio.wait_for`` (direct path) or is forwarded as
        ``callback_timeout`` to the ``sync`` helper.

    Returns
    -------
    On the direct path, the object returned by ``func`` (wrapped in
    ``asyncio.wait_for`` when a timeout is set) for the caller to await.
    Otherwise the result of the ``sync`` helper executed on ``self.loop``
    (presumably blocking until completion -- confirm against the helper).
    """
    callback_timeout = parse_timedelta(callback_timeout)

    # Same short-circuit order as before: explicit flag, then this
    # instance's loop check, then the thread-local marker.
    run_directly = (
        asynchronous
        or self.asynchronous
        or getattr(thread_state, "asynchronous", False)
    )
    if not run_directly:
        return sync(
            self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
        )

    awaitable = func(*args, **kwargs)
    if callback_timeout is None:
        return awaitable
    return asyncio.wait_for(awaitable, callback_timeout)

async def _refresh_leases(self):
if self.refresh_leases and self._leases:
logger.debug(
"%s refreshing leases for %s with IDs %s",
self.client.id,
self.id,
self.name,
self._leases,
)
await self.client.scheduler.semaphore_refresh_leases(
lease_ids=list(self._leases), name=self.name
await retry_operation(
self.scheduler.semaphore_refresh_leases,
lease_ids=list(self._leases),
name=self.name,
operation="semaphore refresh leases: id=%s, lease_ids=%s, name=%s"
% (self.id, list(self._leases), self.name),
)

async def _acquire(self, timeout=None):
lease_id = uuid.uuid4().hex
logger.info(
"%s requests lease for %s with ID %s", self.client.id, self.name, lease_id
logger.debug(
"%s requests lease for %s with ID %s", self.id, self.name, lease_id
)

# Using a unique lease id generated here allows us to retry since the
# server handle is idempotent

result = await retry_operation(
self.client.scheduler.semaphore_acquire,
self.scheduler.semaphore_acquire,
name=self.name,
timeout=timeout,
lease_id=lease_id,
operation="semaphore acquire: client=%s, lease_id=%s, name=%s"
% (self.client.id, lease_id, self.name),
operation="semaphore acquire: id=%s, lease_id=%s, name=%s"
% (self.id, lease_id, self.name),
)
if result:
self._leases.append(lease_id)
Expand All @@ -460,26 +499,22 @@ def acquire(self, timeout=None):
a timedelta in string format, e.g. "200ms".
"""
timeout = parse_timedelta(timeout)
return self.client.sync(self._acquire, timeout=timeout)

async def _release(self):
# popleft to release the oldest lease first
lease_id = self._leases.popleft()
logger.info("%s releases %s for %s", self.client.id, lease_id, self.name)
return self.sync(self._acquire, timeout=timeout)

async def _release(self, lease_id):
try:
await retry_operation(
self.client.scheduler.semaphore_release,
self.scheduler.semaphore_release,
name=self.name,
lease_id=lease_id,
operation="semaphore release: client=%s, lease_id=%s, name=%s"
% (self.client.id, lease_id, self.name),
operation="semaphore release: id=%s, lease_id=%s, name=%s"
% (self.id, lease_id, self.name),
)
return True
except Exception: # Release fails for whatever reason
logger.error(
"Release failed for client=%s, lease_id=%s, name=%s. Cluster network might be unstable?"
% (self.client.id, lease_id, self.name),
"Release failed for id=%s, lease_id=%s, name=%s. Cluster network might be unstable?"
% (self.id, lease_id, self.name),
exc_info=True,
)
return False
Expand All @@ -499,13 +534,16 @@ def release(self):
if not self._leases:
raise RuntimeError("Released too often")

return self.client.sync(self._release)
# popleft to release the oldest lease first
lease_id = self._leases.popleft()
logger.debug("%s releases %s for %s", self.id, lease_id, self.name)
return self.sync(self._release, lease_id=lease_id)

def get_value(self):
"""
Return the number of currently registered leases.
"""
return self.client.sync(self.client.scheduler.semaphore_value, name=self.name)
return self.sync(self.scheduler.semaphore_value, name=self.name)

def __enter__(self):
self.acquire()
Expand All @@ -528,13 +566,14 @@ def __getstate__(self):

def __setstate__(self, state):
name, max_leases = state
client = get_client()
self.__init__(name=name, client=client, max_leases=max_leases, register=False)
self.__init__(
name=name,
max_leases=max_leases,
register=False,
)

def close(self):
return self.client.sync(self.client.scheduler.semaphore_close, name=self.name)
return self.sync(self.scheduler.semaphore_close, name=self.name)

def __del__(self):
if self._periodic_callback_name in self.client._periodic_callbacks:
self.client._periodic_callbacks[self._periodic_callback_name].stop()
del self.client._periodic_callbacks[self._periodic_callback_name]
self.refresh_callback.stop()
Loading

0 comments on commit e736c0b

Please sign in to comment.